@@ -28,7 +28,7 @@ import (
2828 "sqlflow.org/gomaxcompute"
2929)
3030
31- type xgboostFiller struct {
31+ type antXGBoostFiller struct {
3232 ModelPath string
3333 xgLearningFields
3434 xgColumnFields
@@ -164,15 +164,15 @@ func xgMultiSparseError(colNames []string) error {
164164}
165165
166166func xgUnknownFCError (kw string ) error {
167- return fmt .Errorf ("xgUnknownFCError: feature column keyword(`%s`) is not supported by xgboost engine" , kw )
167+ return fmt .Errorf ("xgUnknownFCError: feature column keyword(`%s`) is not supported by ant- xgboost engine" , kw )
168168}
169169
170170func xgUnsupportedColTagError () error {
171- return fmt .Errorf ("xgUnsupportedColTagError: valid column tags of xgboost engine([feature_columns, group, weight])" )
171+ return fmt .Errorf ("xgUnsupportedColTagError: valid column tags of ant- xgboost engine([feature_columns, group, weight])" )
172172}
173173
174- func uIntPartial (key string , ptrFn func (* xgboostFiller ) * uint ) func (* map [string ][]string , * xgboostFiller ) error {
175- return func (a * map [string ][]string , r * xgboostFiller ) error {
174+ func uIntPartial (key string , ptrFn func (* antXGBoostFiller ) * uint ) func (* map [string ][]string , * antXGBoostFiller ) error {
175+ return func (a * map [string ][]string , r * antXGBoostFiller ) error {
176176 // xgParseAttr will ensure the key is existing in map
177177 val , _ := (* a )[key ]
178178 if len (val ) != 1 {
@@ -190,8 +190,8 @@ func uIntPartial(key string, ptrFn func(*xgboostFiller) *uint) func(*map[string]
190190 }
191191}
192192
193- func fp32Partial (key string , ptrFn func (* xgboostFiller ) * float32 ) func (* map [string ][]string , * xgboostFiller ) error {
194- return func (a * map [string ][]string , r * xgboostFiller ) error {
193+ func fp32Partial (key string , ptrFn func (* antXGBoostFiller ) * float32 ) func (* map [string ][]string , * antXGBoostFiller ) error {
194+ return func (a * map [string ][]string , r * antXGBoostFiller ) error {
195195 // xgParseAttr will ensure the key is existing in map
196196 val , _ := (* a )[key ]
197197 if len (val ) != 1 {
@@ -209,8 +209,8 @@ func fp32Partial(key string, ptrFn func(*xgboostFiller) *float32) func(*map[stri
209209 }
210210}
211211
212- func boolPartial (key string , ptrFn func (* xgboostFiller ) * bool ) func (* map [string ][]string , * xgboostFiller ) error {
213- return func (a * map [string ][]string , r * xgboostFiller ) error {
212+ func boolPartial (key string , ptrFn func (* antXGBoostFiller ) * bool ) func (* map [string ][]string , * antXGBoostFiller ) error {
213+ return func (a * map [string ][]string , r * antXGBoostFiller ) error {
214214 // xgParseAttr will ensure the key is existing in map
215215 val , _ := (* a )[key ]
216216 if len (val ) != 1 {
@@ -227,8 +227,8 @@ func boolPartial(key string, ptrFn func(*xgboostFiller) *bool) func(*map[string]
227227 }
228228}
229229
230- func strPartial (key string , ptrFn func (* xgboostFiller ) * string ) func (* map [string ][]string , * xgboostFiller ) error {
231- return func (a * map [string ][]string , r * xgboostFiller ) error {
230+ func strPartial (key string , ptrFn func (* antXGBoostFiller ) * string ) func (* map [string ][]string , * antXGBoostFiller ) error {
231+ return func (a * map [string ][]string , r * antXGBoostFiller ) error {
232232 // xgParseAttr will ensure the key is existing in map
233233 val , _ := (* a )[key ]
234234 if len (val ) != 1 {
@@ -244,8 +244,8 @@ func strPartial(key string, ptrFn func(*xgboostFiller) *string) func(*map[string
244244 }
245245}
246246
247- func sListPartial (key string , ptrFn func (* xgboostFiller ) * []string ) func (* map [string ][]string , * xgboostFiller ) error {
248- return func (a * map [string ][]string , r * xgboostFiller ) error {
247+ func sListPartial (key string , ptrFn func (* antXGBoostFiller ) * []string ) func (* map [string ][]string , * antXGBoostFiller ) error {
248+ return func (a * map [string ][]string , r * antXGBoostFiller ) error {
249249 // xgParseAttr will ensure the key is existing in map
250250 val , _ := (* a )[key ]
251251 strListPtr := ptrFn (r )
@@ -258,48 +258,48 @@ func sListPartial(key string, ptrFn func(*xgboostFiller) *[]string) func(*map[st
258258 }
259259}
260260
261- var xgbTrainAttrSetterMap = map [string ]func (* map [string ][]string , * xgboostFiller ) error {
261+ var xgbTrainAttrSetterMap = map [string ]func (* map [string ][]string , * antXGBoostFiller ) error {
262262 // booster params
263- "train.objective" : strPartial ("train.objective" , func (r * xgboostFiller ) * string { return & (r .Objective ) }),
264- "train.eval_metric" : strPartial ("train.eval_metric" , func (r * xgboostFiller ) * string { return & (r .EvalMetric ) }),
265- "train.booster" : strPartial ("train.booster" , func (r * xgboostFiller ) * string { return & (r .Booster ) }),
266- "train.seed" : uIntPartial ("train.seed" , func (r * xgboostFiller ) * uint { return & (r .Seed ) }),
267- "train.num_class" : uIntPartial ("train.num_class" , func (r * xgboostFiller ) * uint { return & (r .NumClass ) }),
268- "train.eta" : fp32Partial ("train.eta" , func (r * xgboostFiller ) * float32 { return & (r .Eta ) }),
269- "train.gamma" : fp32Partial ("train.gamma" , func (r * xgboostFiller ) * float32 { return & (r .Gamma ) }),
270- "train.max_depth" : uIntPartial ("train.max_depth" , func (r * xgboostFiller ) * uint { return & (r .MaxDepth ) }),
271- "train.min_child_weight" : uIntPartial ("train.min_child_weight" , func (r * xgboostFiller ) * uint { return & (r .MinChildWeight ) }),
272- "train.subsample" : fp32Partial ("train.subsample" , func (r * xgboostFiller ) * float32 { return & (r .Subsample ) }),
273- "train.colsample_bytree" : fp32Partial ("train.colsample_bytree" , func (r * xgboostFiller ) * float32 { return & (r .ColSampleByTree ) }),
274- "train.colsample_bylevel" : fp32Partial ("train.colsample_bylevel" , func (r * xgboostFiller ) * float32 { return & (r .ColSampleByLevel ) }),
275- "train.colsample_bynode" : fp32Partial ("train.colsample_bynode" , func (r * xgboostFiller ) * float32 { return & (r .ColSampleByNode ) }),
276- "train.lambda" : fp32Partial ("train.lambda" , func (r * xgboostFiller ) * float32 { return & (r .Lambda ) }),
277- "train.alpha" : fp32Partial ("train.alpha" , func (r * xgboostFiller ) * float32 { return & (r .Alpha ) }),
278- "train.tree_method" : strPartial ("train.tree_method" , func (r * xgboostFiller ) * string { return & (r .TreeMethod ) }),
279- "train.sketch_eps" : fp32Partial ("train.sketch_eps" , func (r * xgboostFiller ) * float32 { return & (r .SketchEps ) }),
280- "train.scale_pos_weight" : fp32Partial ("train.scale_pos_weight" , func (r * xgboostFiller ) * float32 { return & (r .ScalePosWeight ) }),
281- "train.grow_policy" : strPartial ("train.grow_policy" , func (r * xgboostFiller ) * string { return & (r .GrowPolicy ) }),
282- "train.max_leaves" : uIntPartial ("train.max_leaves" , func (r * xgboostFiller ) * uint { return & (r .MaxLeaves ) }),
283- "train.max_bin" : uIntPartial ("train.max_bin" , func (r * xgboostFiller ) * uint { return & (r .MaxBin ) }),
284- "train.num_parallel_tree" : uIntPartial ("train.num_parallel_tree" , func (r * xgboostFiller ) * uint { return & (r .NumParallelTree ) }),
285- "train.convergence_criteria" : strPartial ("train.convergence_criteria" , func (r * xgboostFiller ) * string { return & (r .ConvergenceCriteria ) }),
286- "train.verbosity" : uIntPartial ("train.verbosity" , func (r * xgboostFiller ) * uint { return & (r .Verbosity ) }),
263+ "train.objective" : strPartial ("train.objective" , func (r * antXGBoostFiller ) * string { return & (r .Objective ) }),
264+ "train.eval_metric" : strPartial ("train.eval_metric" , func (r * antXGBoostFiller ) * string { return & (r .EvalMetric ) }),
265+ "train.booster" : strPartial ("train.booster" , func (r * antXGBoostFiller ) * string { return & (r .Booster ) }),
266+ "train.seed" : uIntPartial ("train.seed" , func (r * antXGBoostFiller ) * uint { return & (r .Seed ) }),
267+ "train.num_class" : uIntPartial ("train.num_class" , func (r * antXGBoostFiller ) * uint { return & (r .NumClass ) }),
268+ "train.eta" : fp32Partial ("train.eta" , func (r * antXGBoostFiller ) * float32 { return & (r .Eta ) }),
269+ "train.gamma" : fp32Partial ("train.gamma" , func (r * antXGBoostFiller ) * float32 { return & (r .Gamma ) }),
270+ "train.max_depth" : uIntPartial ("train.max_depth" , func (r * antXGBoostFiller ) * uint { return & (r .MaxDepth ) }),
271+ "train.min_child_weight" : uIntPartial ("train.min_child_weight" , func (r * antXGBoostFiller ) * uint { return & (r .MinChildWeight ) }),
272+ "train.subsample" : fp32Partial ("train.subsample" , func (r * antXGBoostFiller ) * float32 { return & (r .Subsample ) }),
273+ "train.colsample_bytree" : fp32Partial ("train.colsample_bytree" , func (r * antXGBoostFiller ) * float32 { return & (r .ColSampleByTree ) }),
274+ "train.colsample_bylevel" : fp32Partial ("train.colsample_bylevel" , func (r * antXGBoostFiller ) * float32 { return & (r .ColSampleByLevel ) }),
275+ "train.colsample_bynode" : fp32Partial ("train.colsample_bynode" , func (r * antXGBoostFiller ) * float32 { return & (r .ColSampleByNode ) }),
276+ "train.lambda" : fp32Partial ("train.lambda" , func (r * antXGBoostFiller ) * float32 { return & (r .Lambda ) }),
277+ "train.alpha" : fp32Partial ("train.alpha" , func (r * antXGBoostFiller ) * float32 { return & (r .Alpha ) }),
278+ "train.tree_method" : strPartial ("train.tree_method" , func (r * antXGBoostFiller ) * string { return & (r .TreeMethod ) }),
279+ "train.sketch_eps" : fp32Partial ("train.sketch_eps" , func (r * antXGBoostFiller ) * float32 { return & (r .SketchEps ) }),
280+ "train.scale_pos_weight" : fp32Partial ("train.scale_pos_weight" , func (r * antXGBoostFiller ) * float32 { return & (r .ScalePosWeight ) }),
281+ "train.grow_policy" : strPartial ("train.grow_policy" , func (r * antXGBoostFiller ) * string { return & (r .GrowPolicy ) }),
282+ "train.max_leaves" : uIntPartial ("train.max_leaves" , func (r * antXGBoostFiller ) * uint { return & (r .MaxLeaves ) }),
283+ "train.max_bin" : uIntPartial ("train.max_bin" , func (r * antXGBoostFiller ) * uint { return & (r .MaxBin ) }),
284+ "train.num_parallel_tree" : uIntPartial ("train.num_parallel_tree" , func (r * antXGBoostFiller ) * uint { return & (r .NumParallelTree ) }),
285+ "train.convergence_criteria" : strPartial ("train.convergence_criteria" , func (r * antXGBoostFiller ) * string { return & (r .ConvergenceCriteria ) }),
286+ "train.verbosity" : uIntPartial ("train.verbosity" , func (r * antXGBoostFiller ) * uint { return & (r .Verbosity ) }),
287287 // xgboost train controllers
288- "train.num_round" : uIntPartial ("train.num_round" , func (r * xgboostFiller ) * uint { return & (r .NumRound ) }),
289- "train.auto_train" : boolPartial ("train.auto_train" , func (r * xgboostFiller ) * bool { return & (r .AutoTrain ) }),
288+ "train.num_round" : uIntPartial ("train.num_round" , func (r * antXGBoostFiller ) * uint { return & (r .NumRound ) }),
289+ "train.auto_train" : boolPartial ("train.auto_train" , func (r * antXGBoostFiller ) * bool { return & (r .AutoTrain ) }),
290290 // Label, Group, Weight and xgFeatureFields are parsed from columnClause
291291}
292292
293- var xgbPredAttrSetterMap = map [string ]func (* map [string ][]string , * xgboostFiller ) error {
293+ var xgbPredAttrSetterMap = map [string ]func (* map [string ][]string , * antXGBoostFiller ) error {
294294 // xgboost output columns (for prediction)
295- "pred.append_columns" : sListPartial ("pred.append_columns" , func (r * xgboostFiller ) * []string { return & (r .AppendColumns ) }),
296- "pred.prob_column" : strPartial ("pred.prob_column" , func (r * xgboostFiller ) * string { return & (r .ProbColumn ) }),
297- "pred.detail_column" : strPartial ("pred.detail_column" , func (r * xgboostFiller ) * string { return & (r .DetailColumn ) }),
298- "pred.encoding_column" : strPartial ("pred.encoding_column" , func (r * xgboostFiller ) * string { return & (r .EncodingColumn ) }),
295+ "pred.append_columns" : sListPartial ("pred.append_columns" , func (r * antXGBoostFiller ) * []string { return & (r .AppendColumns ) }),
296+ "pred.prob_column" : strPartial ("pred.prob_column" , func (r * antXGBoostFiller ) * string { return & (r .ProbColumn ) }),
297+ "pred.detail_column" : strPartial ("pred.detail_column" , func (r * antXGBoostFiller ) * string { return & (r .DetailColumn ) }),
298+ "pred.encoding_column" : strPartial ("pred.encoding_column" , func (r * antXGBoostFiller ) * string { return & (r .EncodingColumn ) }),
299299 // Label, Group, Weight and xgFeatureFields are parsed from columnClause
300300}
301301
302- func xgParseAttr (pr * extendedSelect , r * xgboostFiller ) error {
302+ func xgParseAttr (pr * extendedSelect , r * antXGBoostFiller ) error {
303303 var rawAttrs map [string ]* expr
304304 if pr .train {
305305 rawAttrs = pr .trainAttrs
@@ -324,8 +324,8 @@ func xgParseAttr(pr *extendedSelect, r *xgboostFiller) error {
324324 }
325325 }
326326
327- // fill xgboostFiller with attrs
328- var setterMap map [string ]func (* map [string ][]string , * xgboostFiller ) error
327+ // fill antXGBoostFiller with attrs
328+ var setterMap map [string ]func (* map [string ][]string , * antXGBoostFiller ) error
329329 if pr .train {
330330 setterMap = xgbTrainAttrSetterMap
331331 } else {
@@ -358,7 +358,7 @@ func xgParseAttr(pr *extendedSelect, r *xgboostFiller) error {
358358// data example: COLUMN SPARSE("0:1.5 1:100.1f 11:-1.2", [20], " ")
359359// 2. tf feature columns
360360// Roughly same as TFEstimator, except output shape of feaColumns are required to be 1-dim.
361- func parseFeatureColumns (columns * exprlist , r * xgboostFiller ) error {
361+ func parseFeatureColumns (columns * exprlist , r * antXGBoostFiller ) error {
362362 feaCols , colSpecs , err := resolveTrainColumns (columns )
363363 if err != nil {
364364 return err
@@ -379,7 +379,7 @@ func parseFeatureColumns(columns *exprlist, r *xgboostFiller) error {
379379
380380// parseSparseKeyValueFeatures, parse features which is identified by `SPARSE`.
381381// ex: SPARSE(col1, [100], comma)
382- func parseSparseKeyValueFeatures (colSpecs []* columnSpec , r * xgboostFiller ) error {
382+ func parseSparseKeyValueFeatures (colSpecs []* columnSpec , r * antXGBoostFiller ) error {
383383 var colNames []string
384384 for _ , spec := range colSpecs {
385385 colNames = append (colNames , spec .ColumnName )
@@ -425,7 +425,7 @@ func isSimpleColumn(col featureColumn) bool {
425425 return false
426426}
427427
428- func parseDenseFeatures (feaCols []featureColumn , r * xgboostFiller ) error {
428+ func parseDenseFeatures (feaCols []featureColumn , r * antXGBoostFiller ) error {
429429 allSimpleCol := true
430430 for _ , col := range feaCols {
431431 if allSimpleCol && ! isSimpleColumn (col ) {
@@ -511,7 +511,7 @@ func parseSimpleColumn(field string, columns *exprlist) (*xgFeatureMeta, error)
511511 return fm , nil
512512}
513513
514- func xgParseColumns (pr * extendedSelect , filler * xgboostFiller ) error {
514+ func xgParseColumns (pr * extendedSelect , filler * antXGBoostFiller ) error {
515515 for target , columns := range pr .columns {
516516 switch target {
517517 case "feature_columns" :
@@ -553,7 +553,7 @@ func xgParseColumns(pr *extendedSelect, filler *xgboostFiller) error {
553553 return nil
554554}
555555
556- func xgParseEstimator (pr * extendedSelect , filler * xgboostFiller ) error {
556+ func xgParseEstimator (pr * extendedSelect , filler * antXGBoostFiller ) error {
557557 switch strings .ToUpper (pr .estimator ) {
558558 case "XGBOOST.ESTIMATOR" :
559559 if len (filler .Objective ) == 0 {
@@ -590,8 +590,8 @@ func xgParseEstimator(pr *extendedSelect, filler *xgboostFiller) error {
590590 return nil
591591}
592592
593- func newXGBoostFiller (pr * extendedSelect , ds * trainAndValDataset , db * DB ) (* xgboostFiller , error ) {
594- filler := & xgboostFiller {
593+ func newAntXGBoostFiller (pr * extendedSelect , ds * trainAndValDataset , db * DB ) (* antXGBoostFiller , error ) {
594+ filler := & antXGBoostFiller {
595595 ModelPath : pr .save ,
596596 }
597597 filler .IsTrain = pr .train
@@ -720,7 +720,7 @@ func xgFillDatabaseInfo(r *xgDataSourceFields, db *DB) error {
720720 return nil
721721}
722722
723- func xgCreatePredictionTable (pr * extendedSelect , r * xgboostFiller , db * DB ) error {
723+ func xgCreatePredictionTable (pr * extendedSelect , r * antXGBoostFiller , db * DB ) error {
724724 dropStmt := fmt .Sprintf ("drop table if exists %s;" , r .OutputTable )
725725 if _ , e := db .Exec (dropStmt ); e != nil {
726726 return fmt .Errorf ("failed executing %s: %q" , dropStmt , e )
@@ -791,7 +791,7 @@ func xgCreatePredictionTable(pr *extendedSelect, r *xgboostFiller, db *DB) error
791791}
792792
793793func genXG (w io.Writer , pr * extendedSelect , ds * trainAndValDataset , fts fieldTypes , db * DB ) error {
794- r , e := newXGBoostFiller (pr , ds , db )
794+ r , e := newAntXGBoostFiller (pr , ds , db )
795795 if e != nil {
796796 return e
797797 }
@@ -808,7 +808,7 @@ var xgTemplate = template.Must(template.New("codegenXG").Parse(xgTemplateText))
808808
809809const xgTemplateText = `
810810from launcher.config_fields import JobType
811- from sqlflow_submitter.xgboost import run_with_sqlflow
811+ from sqlflow_submitter.ant_xgboost import run_with_sqlflow
812812
813813{{if .IsTrain}}
814814mode = JobType.TRAIN
0 commit comments