Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion doc/design/design_xgboost_on_sqlflow.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ SELECT * FROM train_table
TRAIN xgboost.gbtree
WITH
objective=multi:softmax,
train.num_round=2,
train.num_boost_round=2,
max_depth=2,
eta=1
LABEL class
Expand Down
70 changes: 26 additions & 44 deletions pkg/sql/codegen/xgboost/codegen.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,42 +17,33 @@ import (
"bytes"
"encoding/json"
"fmt"
"sqlflow.org/sqlflow/pkg/sql/codegen/attribute"
"strings"

"sqlflow.org/sqlflow/pkg/sql/codegen"
)

var attributeChecker = map[string]func(interface{}) error{
"eta": func(x interface{}) error {
switch x.(type) {
case float32, float64:
return nil
default:
return fmt.Errorf("eta should be of type float, received %T", x)
}
},
"num_class": func(x interface{}) error {
switch x.(type) {
case int, int32, int64:
return nil
default:
return fmt.Errorf("num_class should be of type int, received %T", x)
}
},
"train.num_boost_round": func(x interface{}) error {
switch x.(type) {
case int, int32, int64:
return nil
default:
return fmt.Errorf("train.num_boost_round should be of type int, received %T", x)
}
},
"objective": func(x interface{}) error {
if _, ok := x.(string); !ok {
return fmt.Errorf("objective should be of type string, received %T", x)
}
return nil
},
func newFloat32(f float32) *float32 {
return &f
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not sure whether returning the address of a parameter actually allocates a new float value. Maybe

return &new(float32)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

}

// newInt allocates a fresh int, stores i in it, and returns its address.
// Each call yields a distinct pointer, so callers may use the result as an
// optional integer value (for example, a range bound that can also be nil).
func newInt(i int) *int {
	p := new(int)
	*p = i
	return p
}

// TODO(tony): complete model parameter and training parameter list
// model parameter list: https://xgboost.readthedocs.io/en/latest/parameter.html#general-parameters
// training parameter list: https://github.com/dmlc/xgboost/blob/b61d53447203ca7a321d72f6bdd3f553a3aa06c4/python-package/xgboost/training.py#L115-L117
//
// attributeDictionary lists the attributes accepted by this package. Each
// entry is keyed by the attribute name and holds, in order, an attribute type
// tag, a help text, and a checker (nil for "objective"). parseAttribute calls
// attributeDictionary.Validate on the user-supplied attributes before
// splitting them by prefix.
// NOTE(review): the two trailing booleans passed to the range checkers
// presumably select inclusive vs. exclusive lower/upper bounds, and a nil
// bound presumably means "unbounded" — confirm against the attribute package.
var attributeDictionary = attribute.Dictionary{
// eta: float attribute; its checker is built from the bounds 0 and 1.
"eta": {attribute.Float, `[default=0.3, alias: learning_rate]
Step size shrinkage used in update to prevents overfitting. After each boosting step, we can directly get the weights of new features, and eta shrinks the feature weights to make the boosting process more conservative.
range: [0,1]`, attribute.Float32RangeChecker(newFloat32(0), newFloat32(1), true, true)},
// num_class: int attribute; its checker is built from a lower bound of 0 and a nil upper bound.
"num_class": {attribute.Int, `Number of classes.
range: [1, Infinity]`, attribute.IntRangeChecker(newInt(0), nil, false, false)},
// objective: string attribute; no checker is attached.
"objective": {attribute.String, `Learning objective`, nil},
// train.num_boost_round: int attribute carrying the "train." prefix; its
// checker is built from a lower bound of 0 and a nil upper bound.
"train.num_boost_round": {attribute.Int, `[default=10]
The number of rounds for boosting.
range: [1, Infinity]`, attribute.IntRangeChecker(newInt(0), nil, false, false)},
}

func resolveModelType(estimator string) (string, error) {
Expand All @@ -69,22 +60,13 @@ func resolveModelType(estimator string) (string, error) {
}

func parseAttribute(attrs map[string]interface{}) (map[string]map[string]interface{}, error) {
attrNames := map[string]bool{}
if err := attributeDictionary.Validate(attrs); err != nil {
return nil, err
}

params := map[string]map[string]interface{}{"": {}, "train.": {}}
paramPrefix := []string{"train.", ""} // use slice to assure traverse order
paramPrefix := []string{"train.", ""} // use slice to assure traverse order, this is necessary because all string starts with ""
for key, attr := range attrs {
if _, ok := attrNames[key]; ok {
return nil, fmt.Errorf("duplicated attribute %s", key)
}
attrNames[key] = true
checker, ok := attributeChecker[key]
if !ok {
return nil, fmt.Errorf("unrecognized attribute %v", key)
}
if err := checker(attr); err != nil {
return nil, err
}
for _, pp := range paramPrefix {
if strings.HasPrefix(key, pp) {
params[pp][key[len(pp):]] = attr
Expand Down
4 changes: 2 additions & 2 deletions pkg/sql/codegen/xgboost/codegen_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,9 @@ func TestTrain(t *testing.T) {
ValidationSelect: "select * from iris.test;",
Estimator: "xgboost.gbtree",
Attributes: map[string]interface{}{
"train.num_boost_round": 30,
"train.num_boost_round": 10,
"objective": "multi:softprob",
"eta": 3.1,
"eta": float32(0.1),
"num_class": 3},
Features: map[string][]codegen.FeatureColumn{
"feature_columns": {
Expand Down