Skip to content

Commit 9f9bba7

Browse files
authored
rename xgboost to ant_xgboost (#754)
* rename xgboost to ant_xgboost * update * rename xgboosterror to antxgboosterror * fix ci * fix ci * fix ci * train xgboost which can be consistent with xgboost and ant-xgboost
1 parent b7ce643 commit 9f9bba7

13 files changed

+108
-108
lines changed

doc/xgboost_on_sqlflow_design.md renamed to doc/ant-xgboost_on_sqlflow_design.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
1-
# _Design:_ xgboost on sqlflow
1+
# _Design:_ ant-xgboost on sqlflow
22

33
## Overview
44

5-
This is a design doc about why and how to support running xgboost via sqlflow as a machine learning estimator.
5+
This is a design doc about why and how to support running ant-xgboost via sqlflow as a machine learning estimator.
66

77
We propose to build a lightweight Python template for xgboost on the basis of `xgblauncher`,
88
an incubating xgboost wrapper in [ant-xgboost](https://github.com/alipay/ant-xgboost).

scripts/test_e2e.sh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,8 +48,8 @@ export PYTHONPATH=$GOPATH/src/github.com/sql-machine-learning/sqlflow/sql/python
4848
sqlflowserver --datasource=${DATASOURCE} &
4949
# e2e test for standard SQL
5050
SQLFLOW_SERVER=localhost:50051 ipython sql/python/test_magic.py
51-
# e2e test for xgboost train and prediction SQL.
52-
SQLFLOW_SERVER=localhost:50051 ipython sql/python/test_magic_xgboost.py
51+
# e2e test for ant-xgboost train and prediction SQL.
52+
SQLFLOW_SERVER=localhost:50051 ipython sql/python/test_magic_ant_xgboost.py
5353
# TODO(terrytangyuan): Enable this when ElasticDL is open sourced
5454
# e2e test for ElasticDL SQL
5555
# export SQLFLOW_submitter=elasticdl

sql/codegen_analyze.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ func readFeatureNames(pr *extendedSelect, db *DB) ([]string, string, error) {
4141
if strings.HasPrefix(strings.ToUpper(pr.estimator), `XGBOOST.`) {
4242
// TODO(weiguo): It's a quick way to read column and label names from
4343
// xgboost.*, but too heavy.
44-
xgbFiller, err := newXGBoostFiller(pr, nil, db)
44+
xgbFiller, err := newAntXGBoostFiller(pr, nil, db)
4545
if err != nil {
4646
return nil, "", err
4747
}

sql/codegen_xgboost.go renamed to sql/codegen_ant_xgboost.go

Lines changed: 58 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ import (
2828
"sqlflow.org/gomaxcompute"
2929
)
3030

31-
type xgboostFiller struct {
31+
type antXGBoostFiller struct {
3232
ModelPath string
3333
xgLearningFields
3434
xgColumnFields
@@ -164,15 +164,15 @@ func xgMultiSparseError(colNames []string) error {
164164
}
165165

166166
func xgUnknownFCError(kw string) error {
167-
return fmt.Errorf("xgUnknownFCError: feature column keyword(`%s`) is not supported by xgboost engine", kw)
167+
return fmt.Errorf("xgUnknownFCError: feature column keyword(`%s`) is not supported by ant-xgboost engine", kw)
168168
}
169169

170170
func xgUnsupportedColTagError() error {
171-
return fmt.Errorf("xgUnsupportedColTagError: valid column tags of xgboost engine([feature_columns, group, weight])")
171+
return fmt.Errorf("xgUnsupportedColTagError: valid column tags of ant-xgboost engine([feature_columns, group, weight])")
172172
}
173173

174-
func uIntPartial(key string, ptrFn func(*xgboostFiller) *uint) func(*map[string][]string, *xgboostFiller) error {
175-
return func(a *map[string][]string, r *xgboostFiller) error {
174+
func uIntPartial(key string, ptrFn func(*antXGBoostFiller) *uint) func(*map[string][]string, *antXGBoostFiller) error {
175+
return func(a *map[string][]string, r *antXGBoostFiller) error {
176176
// xgParseAttr will ensure the key is existing in map
177177
val, _ := (*a)[key]
178178
if len(val) != 1 {
@@ -190,8 +190,8 @@ func uIntPartial(key string, ptrFn func(*xgboostFiller) *uint) func(*map[string]
190190
}
191191
}
192192

193-
func fp32Partial(key string, ptrFn func(*xgboostFiller) *float32) func(*map[string][]string, *xgboostFiller) error {
194-
return func(a *map[string][]string, r *xgboostFiller) error {
193+
func fp32Partial(key string, ptrFn func(*antXGBoostFiller) *float32) func(*map[string][]string, *antXGBoostFiller) error {
194+
return func(a *map[string][]string, r *antXGBoostFiller) error {
195195
// xgParseAttr will ensure the key is existing in map
196196
val, _ := (*a)[key]
197197
if len(val) != 1 {
@@ -209,8 +209,8 @@ func fp32Partial(key string, ptrFn func(*xgboostFiller) *float32) func(*map[stri
209209
}
210210
}
211211

212-
func boolPartial(key string, ptrFn func(*xgboostFiller) *bool) func(*map[string][]string, *xgboostFiller) error {
213-
return func(a *map[string][]string, r *xgboostFiller) error {
212+
func boolPartial(key string, ptrFn func(*antXGBoostFiller) *bool) func(*map[string][]string, *antXGBoostFiller) error {
213+
return func(a *map[string][]string, r *antXGBoostFiller) error {
214214
// xgParseAttr will ensure the key is existing in map
215215
val, _ := (*a)[key]
216216
if len(val) != 1 {
@@ -227,8 +227,8 @@ func boolPartial(key string, ptrFn func(*xgboostFiller) *bool) func(*map[string]
227227
}
228228
}
229229

230-
func strPartial(key string, ptrFn func(*xgboostFiller) *string) func(*map[string][]string, *xgboostFiller) error {
231-
return func(a *map[string][]string, r *xgboostFiller) error {
230+
func strPartial(key string, ptrFn func(*antXGBoostFiller) *string) func(*map[string][]string, *antXGBoostFiller) error {
231+
return func(a *map[string][]string, r *antXGBoostFiller) error {
232232
// xgParseAttr will ensure the key is existing in map
233233
val, _ := (*a)[key]
234234
if len(val) != 1 {
@@ -244,8 +244,8 @@ func strPartial(key string, ptrFn func(*xgboostFiller) *string) func(*map[string
244244
}
245245
}
246246

247-
func sListPartial(key string, ptrFn func(*xgboostFiller) *[]string) func(*map[string][]string, *xgboostFiller) error {
248-
return func(a *map[string][]string, r *xgboostFiller) error {
247+
func sListPartial(key string, ptrFn func(*antXGBoostFiller) *[]string) func(*map[string][]string, *antXGBoostFiller) error {
248+
return func(a *map[string][]string, r *antXGBoostFiller) error {
249249
// xgParseAttr will ensure the key is existing in map
250250
val, _ := (*a)[key]
251251
strListPtr := ptrFn(r)
@@ -258,48 +258,48 @@ func sListPartial(key string, ptrFn func(*xgboostFiller) *[]string) func(*map[st
258258
}
259259
}
260260

261-
var xgbTrainAttrSetterMap = map[string]func(*map[string][]string, *xgboostFiller) error{
261+
var xgbTrainAttrSetterMap = map[string]func(*map[string][]string, *antXGBoostFiller) error{
262262
// booster params
263-
"train.objective": strPartial("train.objective", func(r *xgboostFiller) *string { return &(r.Objective) }),
264-
"train.eval_metric": strPartial("train.eval_metric", func(r *xgboostFiller) *string { return &(r.EvalMetric) }),
265-
"train.booster": strPartial("train.booster", func(r *xgboostFiller) *string { return &(r.Booster) }),
266-
"train.seed": uIntPartial("train.seed", func(r *xgboostFiller) *uint { return &(r.Seed) }),
267-
"train.num_class": uIntPartial("train.num_class", func(r *xgboostFiller) *uint { return &(r.NumClass) }),
268-
"train.eta": fp32Partial("train.eta", func(r *xgboostFiller) *float32 { return &(r.Eta) }),
269-
"train.gamma": fp32Partial("train.gamma", func(r *xgboostFiller) *float32 { return &(r.Gamma) }),
270-
"train.max_depth": uIntPartial("train.max_depth", func(r *xgboostFiller) *uint { return &(r.MaxDepth) }),
271-
"train.min_child_weight": uIntPartial("train.min_child_weight", func(r *xgboostFiller) *uint { return &(r.MinChildWeight) }),
272-
"train.subsample": fp32Partial("train.subsample", func(r *xgboostFiller) *float32 { return &(r.Subsample) }),
273-
"train.colsample_bytree": fp32Partial("train.colsample_bytree", func(r *xgboostFiller) *float32 { return &(r.ColSampleByTree) }),
274-
"train.colsample_bylevel": fp32Partial("train.colsample_bylevel", func(r *xgboostFiller) *float32 { return &(r.ColSampleByLevel) }),
275-
"train.colsample_bynode": fp32Partial("train.colsample_bynode", func(r *xgboostFiller) *float32 { return &(r.ColSampleByNode) }),
276-
"train.lambda": fp32Partial("train.lambda", func(r *xgboostFiller) *float32 { return &(r.Lambda) }),
277-
"train.alpha": fp32Partial("train.alpha", func(r *xgboostFiller) *float32 { return &(r.Alpha) }),
278-
"train.tree_method": strPartial("train.tree_method", func(r *xgboostFiller) *string { return &(r.TreeMethod) }),
279-
"train.sketch_eps": fp32Partial("train.sketch_eps", func(r *xgboostFiller) *float32 { return &(r.SketchEps) }),
280-
"train.scale_pos_weight": fp32Partial("train.scale_pos_weight", func(r *xgboostFiller) *float32 { return &(r.ScalePosWeight) }),
281-
"train.grow_policy": strPartial("train.grow_policy", func(r *xgboostFiller) *string { return &(r.GrowPolicy) }),
282-
"train.max_leaves": uIntPartial("train.max_leaves", func(r *xgboostFiller) *uint { return &(r.MaxLeaves) }),
283-
"train.max_bin": uIntPartial("train.max_bin", func(r *xgboostFiller) *uint { return &(r.MaxBin) }),
284-
"train.num_parallel_tree": uIntPartial("train.num_parallel_tree", func(r *xgboostFiller) *uint { return &(r.NumParallelTree) }),
285-
"train.convergence_criteria": strPartial("train.convergence_criteria", func(r *xgboostFiller) *string { return &(r.ConvergenceCriteria) }),
286-
"train.verbosity": uIntPartial("train.verbosity", func(r *xgboostFiller) *uint { return &(r.Verbosity) }),
263+
"train.objective": strPartial("train.objective", func(r *antXGBoostFiller) *string { return &(r.Objective) }),
264+
"train.eval_metric": strPartial("train.eval_metric", func(r *antXGBoostFiller) *string { return &(r.EvalMetric) }),
265+
"train.booster": strPartial("train.booster", func(r *antXGBoostFiller) *string { return &(r.Booster) }),
266+
"train.seed": uIntPartial("train.seed", func(r *antXGBoostFiller) *uint { return &(r.Seed) }),
267+
"train.num_class": uIntPartial("train.num_class", func(r *antXGBoostFiller) *uint { return &(r.NumClass) }),
268+
"train.eta": fp32Partial("train.eta", func(r *antXGBoostFiller) *float32 { return &(r.Eta) }),
269+
"train.gamma": fp32Partial("train.gamma", func(r *antXGBoostFiller) *float32 { return &(r.Gamma) }),
270+
"train.max_depth": uIntPartial("train.max_depth", func(r *antXGBoostFiller) *uint { return &(r.MaxDepth) }),
271+
"train.min_child_weight": uIntPartial("train.min_child_weight", func(r *antXGBoostFiller) *uint { return &(r.MinChildWeight) }),
272+
"train.subsample": fp32Partial("train.subsample", func(r *antXGBoostFiller) *float32 { return &(r.Subsample) }),
273+
"train.colsample_bytree": fp32Partial("train.colsample_bytree", func(r *antXGBoostFiller) *float32 { return &(r.ColSampleByTree) }),
274+
"train.colsample_bylevel": fp32Partial("train.colsample_bylevel", func(r *antXGBoostFiller) *float32 { return &(r.ColSampleByLevel) }),
275+
"train.colsample_bynode": fp32Partial("train.colsample_bynode", func(r *antXGBoostFiller) *float32 { return &(r.ColSampleByNode) }),
276+
"train.lambda": fp32Partial("train.lambda", func(r *antXGBoostFiller) *float32 { return &(r.Lambda) }),
277+
"train.alpha": fp32Partial("train.alpha", func(r *antXGBoostFiller) *float32 { return &(r.Alpha) }),
278+
"train.tree_method": strPartial("train.tree_method", func(r *antXGBoostFiller) *string { return &(r.TreeMethod) }),
279+
"train.sketch_eps": fp32Partial("train.sketch_eps", func(r *antXGBoostFiller) *float32 { return &(r.SketchEps) }),
280+
"train.scale_pos_weight": fp32Partial("train.scale_pos_weight", func(r *antXGBoostFiller) *float32 { return &(r.ScalePosWeight) }),
281+
"train.grow_policy": strPartial("train.grow_policy", func(r *antXGBoostFiller) *string { return &(r.GrowPolicy) }),
282+
"train.max_leaves": uIntPartial("train.max_leaves", func(r *antXGBoostFiller) *uint { return &(r.MaxLeaves) }),
283+
"train.max_bin": uIntPartial("train.max_bin", func(r *antXGBoostFiller) *uint { return &(r.MaxBin) }),
284+
"train.num_parallel_tree": uIntPartial("train.num_parallel_tree", func(r *antXGBoostFiller) *uint { return &(r.NumParallelTree) }),
285+
"train.convergence_criteria": strPartial("train.convergence_criteria", func(r *antXGBoostFiller) *string { return &(r.ConvergenceCriteria) }),
286+
"train.verbosity": uIntPartial("train.verbosity", func(r *antXGBoostFiller) *uint { return &(r.Verbosity) }),
287287
// xgboost train controllers
288-
"train.num_round": uIntPartial("train.num_round", func(r *xgboostFiller) *uint { return &(r.NumRound) }),
289-
"train.auto_train": boolPartial("train.auto_train", func(r *xgboostFiller) *bool { return &(r.AutoTrain) }),
288+
"train.num_round": uIntPartial("train.num_round", func(r *antXGBoostFiller) *uint { return &(r.NumRound) }),
289+
"train.auto_train": boolPartial("train.auto_train", func(r *antXGBoostFiller) *bool { return &(r.AutoTrain) }),
290290
// Label, Group, Weight and xgFeatureFields are parsed from columnClause
291291
}
292292

293-
var xgbPredAttrSetterMap = map[string]func(*map[string][]string, *xgboostFiller) error{
293+
var xgbPredAttrSetterMap = map[string]func(*map[string][]string, *antXGBoostFiller) error{
294294
// xgboost output columns (for prediction)
295-
"pred.append_columns": sListPartial("pred.append_columns", func(r *xgboostFiller) *[]string { return &(r.AppendColumns) }),
296-
"pred.prob_column": strPartial("pred.prob_column", func(r *xgboostFiller) *string { return &(r.ProbColumn) }),
297-
"pred.detail_column": strPartial("pred.detail_column", func(r *xgboostFiller) *string { return &(r.DetailColumn) }),
298-
"pred.encoding_column": strPartial("pred.encoding_column", func(r *xgboostFiller) *string { return &(r.EncodingColumn) }),
295+
"pred.append_columns": sListPartial("pred.append_columns", func(r *antXGBoostFiller) *[]string { return &(r.AppendColumns) }),
296+
"pred.prob_column": strPartial("pred.prob_column", func(r *antXGBoostFiller) *string { return &(r.ProbColumn) }),
297+
"pred.detail_column": strPartial("pred.detail_column", func(r *antXGBoostFiller) *string { return &(r.DetailColumn) }),
298+
"pred.encoding_column": strPartial("pred.encoding_column", func(r *antXGBoostFiller) *string { return &(r.EncodingColumn) }),
299299
// Label, Group, Weight and xgFeatureFields are parsed from columnClause
300300
}
301301

302-
func xgParseAttr(pr *extendedSelect, r *xgboostFiller) error {
302+
func xgParseAttr(pr *extendedSelect, r *antXGBoostFiller) error {
303303
var rawAttrs map[string]*expr
304304
if pr.train {
305305
rawAttrs = pr.trainAttrs
@@ -324,8 +324,8 @@ func xgParseAttr(pr *extendedSelect, r *xgboostFiller) error {
324324
}
325325
}
326326

327-
// fill xgboostFiller with attrs
328-
var setterMap map[string]func(*map[string][]string, *xgboostFiller) error
327+
// fill antXGBoostFiller with attrs
328+
var setterMap map[string]func(*map[string][]string, *antXGBoostFiller) error
329329
if pr.train {
330330
setterMap = xgbTrainAttrSetterMap
331331
} else {
@@ -358,7 +358,7 @@ func xgParseAttr(pr *extendedSelect, r *xgboostFiller) error {
358358
// data example: COLUMN SPARSE("0:1.5 1:100.1f 11:-1.2", [20], " ")
359359
// 2. tf feature columns
360360
// Roughly same as TFEstimator, except output shape of feaColumns are required to be 1-dim.
361-
func parseFeatureColumns(columns *exprlist, r *xgboostFiller) error {
361+
func parseFeatureColumns(columns *exprlist, r *antXGBoostFiller) error {
362362
feaCols, colSpecs, err := resolveTrainColumns(columns)
363363
if err != nil {
364364
return err
@@ -379,7 +379,7 @@ func parseFeatureColumns(columns *exprlist, r *xgboostFiller) error {
379379

380380
// parseSparseKeyValueFeatures parses features which are identified by `SPARSE`.
381381
// ex: SPARSE(col1, [100], comma)
382-
func parseSparseKeyValueFeatures(colSpecs []*columnSpec, r *xgboostFiller) error {
382+
func parseSparseKeyValueFeatures(colSpecs []*columnSpec, r *antXGBoostFiller) error {
383383
var colNames []string
384384
for _, spec := range colSpecs {
385385
colNames = append(colNames, spec.ColumnName)
@@ -425,7 +425,7 @@ func isSimpleColumn(col featureColumn) bool {
425425
return false
426426
}
427427

428-
func parseDenseFeatures(feaCols []featureColumn, r *xgboostFiller) error {
428+
func parseDenseFeatures(feaCols []featureColumn, r *antXGBoostFiller) error {
429429
allSimpleCol := true
430430
for _, col := range feaCols {
431431
if allSimpleCol && !isSimpleColumn(col) {
@@ -511,7 +511,7 @@ func parseSimpleColumn(field string, columns *exprlist) (*xgFeatureMeta, error)
511511
return fm, nil
512512
}
513513

514-
func xgParseColumns(pr *extendedSelect, filler *xgboostFiller) error {
514+
func xgParseColumns(pr *extendedSelect, filler *antXGBoostFiller) error {
515515
for target, columns := range pr.columns {
516516
switch target {
517517
case "feature_columns":
@@ -553,7 +553,7 @@ func xgParseColumns(pr *extendedSelect, filler *xgboostFiller) error {
553553
return nil
554554
}
555555

556-
func xgParseEstimator(pr *extendedSelect, filler *xgboostFiller) error {
556+
func xgParseEstimator(pr *extendedSelect, filler *antXGBoostFiller) error {
557557
switch strings.ToUpper(pr.estimator) {
558558
case "XGBOOST.ESTIMATOR":
559559
if len(filler.Objective) == 0 {
@@ -590,8 +590,8 @@ func xgParseEstimator(pr *extendedSelect, filler *xgboostFiller) error {
590590
return nil
591591
}
592592

593-
func newXGBoostFiller(pr *extendedSelect, ds *trainAndValDataset, db *DB) (*xgboostFiller, error) {
594-
filler := &xgboostFiller{
593+
func newAntXGBoostFiller(pr *extendedSelect, ds *trainAndValDataset, db *DB) (*antXGBoostFiller, error) {
594+
filler := &antXGBoostFiller{
595595
ModelPath: pr.save,
596596
}
597597
filler.IsTrain = pr.train
@@ -720,7 +720,7 @@ func xgFillDatabaseInfo(r *xgDataSourceFields, db *DB) error {
720720
return nil
721721
}
722722

723-
func xgCreatePredictionTable(pr *extendedSelect, r *xgboostFiller, db *DB) error {
723+
func xgCreatePredictionTable(pr *extendedSelect, r *antXGBoostFiller, db *DB) error {
724724
dropStmt := fmt.Sprintf("drop table if exists %s;", r.OutputTable)
725725
if _, e := db.Exec(dropStmt); e != nil {
726726
return fmt.Errorf("failed executing %s: %q", dropStmt, e)
@@ -791,7 +791,7 @@ func xgCreatePredictionTable(pr *extendedSelect, r *xgboostFiller, db *DB) error
791791
}
792792

793793
func genXG(w io.Writer, pr *extendedSelect, ds *trainAndValDataset, fts fieldTypes, db *DB) error {
794-
r, e := newXGBoostFiller(pr, ds, db)
794+
r, e := newAntXGBoostFiller(pr, ds, db)
795795
if e != nil {
796796
return e
797797
}
@@ -808,7 +808,7 @@ var xgTemplate = template.Must(template.New("codegenXG").Parse(xgTemplateText))
808808

809809
const xgTemplateText = `
810810
from launcher.config_fields import JobType
811-
from sqlflow_submitter.xgboost import run_with_sqlflow
811+
from sqlflow_submitter.ant_xgboost import run_with_sqlflow
812812
813813
{{if .IsTrain}}
814814
mode = JobType.TRAIN

0 commit comments

Comments
 (0)