Skip to content

Commit ccea266

Browse files
authored
generate python feacol code (#2797)
1 parent 2443661 commit ccea266

File tree

9 files changed

+197
-39
lines changed

9 files changed

+197
-39
lines changed

go/codegen/alps/codegen.go

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@ import (
1919
"strings"
2020
"text/template"
2121

22-
"sqlflow.org/sqlflow/go/codegen"
2322
"sqlflow.org/sqlflow/go/codegen/tensorflow"
2423
"sqlflow.org/sqlflow/go/ir"
2524
pb "sqlflow.org/sqlflow/go/proto"
@@ -51,9 +50,9 @@ func Train(trainStmt *ir.TrainStmt, session *pb.Session) (string, error) {
5150

5251
var program bytes.Buffer
5352
var trainTemplate = template.Must(template.New("Train").Funcs(template.FuncMap{
54-
"intArrayToJSONString": codegen.MarshalToJSONString,
55-
"attrToPythonValue": codegen.AttrToPythonValue,
56-
"DTypeToString": codegen.DTypeToString,
53+
"intArrayToJSONString": ir.MarshalToJSONString,
54+
"attrToPythonValue": ir.AttrToPythonValue,
55+
"DTypeToString": ir.DTypeToString,
5756
}).Parse(templateTrain))
5857
if err := trainTemplate.Execute(&program, filler); err != nil {
5958
return "", err

go/codegen/codegen_feature_column.go

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,10 @@
1414
package codegen
1515

1616
import (
17-
"encoding/json"
1817
"fmt"
19-
"sqlflow.org/sqlflow/go/ir"
2018
"strings"
19+
20+
"sqlflow.org/sqlflow/go/ir"
2121
)
2222

2323
func toModuleDataType(dtype int, module string) (string, error) {
@@ -40,17 +40,11 @@ func isXGBoostModule(module string) bool {
4040
return strings.HasPrefix(module, "xgboost")
4141
}
4242

43-
// MarshalToJSONString converts any data to JSON string.
44-
func MarshalToJSONString(in interface{}) (string, error) {
45-
bytes, err := json.Marshal(in)
46-
return string(bytes), err
47-
}
48-
4943
// GenerateFeatureColumnCode generates feature column code for both TensorFlow and XGBoost models
5044
func GenerateFeatureColumnCode(fc ir.FeatureColumn, module string) (string, error) {
5145
switch c := fc.(type) {
5246
case *ir.NumericColumn:
53-
shapeStr, err := MarshalToJSONString(c.FieldDesc.Shape)
47+
shapeStr, err := ir.MarshalToJSONString(c.FieldDesc.Shape)
5448
if err != nil {
5549
return "", err
5650
}
@@ -63,7 +57,7 @@ func GenerateFeatureColumnCode(fc ir.FeatureColumn, module string) (string, erro
6357
if err != nil {
6458
return "", err
6559
}
66-
boundariesStr, err := MarshalToJSONString(c.Boundaries)
60+
boundariesStr, err := ir.MarshalToJSONString(c.Boundaries)
6761
if err != nil {
6862
return "", nil
6963
}

go/codegen/experimental/xgboost.go

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@ import (
2323
"text/template"
2424

2525
"sqlflow.org/sqlflow/go/attribute"
26-
"sqlflow.org/sqlflow/go/codegen"
2726
"sqlflow.org/sqlflow/go/ir"
2827
pb "sqlflow.org/sqlflow/go/proto"
2928
)
@@ -197,11 +196,11 @@ func generateFeatureColumnCode(fcList []ir.FeatureColumn) (string, error) {
197196
}
198197

199198
code := fmt.Sprintf(tmpl, fcTypeName, fd.Name,
200-
strings.ToUpper(codegen.DTypeToString(fd.DType)),
199+
strings.ToUpper(ir.DTypeToString(fd.DType)),
201200
fd.Delimiter,
202-
codegen.AttrToPythonValue(shape),
201+
ir.AttrToPythonValue(shape),
203202
isSparseStr,
204-
codegen.AttrToPythonValue(vocabList))
203+
ir.AttrToPythonValue(vocabList))
205204
fcCodes = append(fcCodes, code)
206205
}
207206

go/codegen/tensorflow/codegen.go

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -290,9 +290,9 @@ func Train(trainStmt *ir.TrainStmt, session *pb.Session) (string, error) {
290290
}
291291
var program bytes.Buffer
292292
var trainTemplate = template.Must(template.New("Train").Funcs(template.FuncMap{
293-
"intArrayToJSONString": codegen.MarshalToJSONString,
294-
"attrToPythonValue": codegen.AttrToPythonValue,
295-
"DTypeToString": codegen.DTypeToString,
293+
"intArrayToJSONString": ir.MarshalToJSONString,
294+
"attrToPythonValue": ir.AttrToPythonValue,
295+
"DTypeToString": ir.DTypeToString,
296296
}).Parse(tfTrainTemplateText))
297297
if err := trainTemplate.Execute(&program, filler); err != nil {
298298
return "", err
@@ -336,9 +336,9 @@ func Pred(predStmt *ir.PredictStmt, session *pb.Session) (string, error) {
336336
}
337337
var program bytes.Buffer
338338
var predTemplate = template.Must(template.New("Pred").Funcs(template.FuncMap{
339-
"intArrayToJSONString": codegen.MarshalToJSONString,
340-
"attrToPythonValue": codegen.AttrToPythonValue,
341-
"DTypeToString": codegen.DTypeToString,
339+
"intArrayToJSONString": ir.MarshalToJSONString,
340+
"attrToPythonValue": ir.AttrToPythonValue,
341+
"DTypeToString": ir.DTypeToString,
342342
}).Parse(tfPredTemplateText))
343343
if err := predTemplate.Execute(&program, filler); err != nil {
344344
return "", err
@@ -380,9 +380,9 @@ func Explain(stmt *ir.ExplainStmt, session *pb.Session) (string, error) {
380380
}
381381
var program bytes.Buffer
382382
var tmpl = template.Must(template.New("Explain").Funcs(template.FuncMap{
383-
"intArrayToJSONString": codegen.MarshalToJSONString,
384-
"attrToPythonValue": codegen.AttrToPythonValue,
385-
"DTypeToString": codegen.DTypeToString,
383+
"intArrayToJSONString": ir.MarshalToJSONString,
384+
"attrToPythonValue": ir.AttrToPythonValue,
385+
"DTypeToString": ir.DTypeToString,
386386
}).Parse(boostedTreesExplainTemplateText))
387387
if err := tmpl.Execute(&program, filler); err != nil {
388388
return "", err
@@ -421,9 +421,9 @@ func Evaluate(stmt *ir.EvaluateStmt, session *pb.Session) (string, error) {
421421
}
422422
var program bytes.Buffer
423423
var tmpl = template.Must(template.New("Evaluate").Funcs(template.FuncMap{
424-
"intArrayToJSONString": codegen.MarshalToJSONString,
425-
"attrToPythonValue": codegen.AttrToPythonValue,
426-
"DTypeToString": codegen.DTypeToString,
424+
"intArrayToJSONString": ir.MarshalToJSONString,
425+
"attrToPythonValue": ir.AttrToPythonValue,
426+
"DTypeToString": ir.DTypeToString,
427427
}).Parse(tfEvaluateTemplateText))
428428
if err := tmpl.Execute(&program, filler); err != nil {
429429
return "", err

go/codegen/xgboost/codegen.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,7 @@ type FieldMeta struct {
186186
func resolveFieldMeta(desc *ir.FieldDesc) FieldMeta {
187187
return FieldMeta{
188188
FeatureName: desc.Name,
189-
DType: codegen.DTypeToString(desc.DType),
189+
DType: ir.DTypeToString(desc.DType),
190190
Delimiter: desc.Delimiter,
191191
Format: desc.Format,
192192
Shap: desc.Shape,

go/codegen/codegen_python_values.go renamed to go/ir/codegen_python_values.go

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,23 +11,22 @@
1111
// See the License for the specific language governing permissions and
1212
// limitations under the License.
1313

14-
package codegen
14+
package ir
1515

1616
import (
17+
"encoding/json"
1718
"fmt"
1819
"strings"
19-
20-
"sqlflow.org/sqlflow/go/ir"
2120
)
2221

2322
// DTypeToString returns string value of dtype
2423
func DTypeToString(dt int) string {
2524
switch dt {
26-
case ir.Float:
25+
case Float:
2726
return "float32"
28-
case ir.Int:
27+
case Int:
2928
return "int64"
30-
case ir.String:
29+
case String:
3130
return "string"
3231
default:
3332
return ""
@@ -77,3 +76,9 @@ func AttrToPythonValue(attr interface{}) string {
7776
return ""
7877
}
7978
}
79+
80+
// MarshalToJSONString serializes in with encoding/json and hands the encoded
// document back as a string. Any error reported by json.Marshal is returned
// unchanged alongside the (then empty) string.
func MarshalToJSONString(in interface{}) (string, error) {
	encoded, err := json.Marshal(in)
	jsonStr := string(encoded)
	return jsonStr, err
}

go/codegen/codegen_python_values_test.go renamed to go/ir/codegen_python_values_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
// See the License for the specific language governing permissions and
1212
// limitations under the License.
1313

14-
package codegen
14+
package ir
1515

1616
import (
1717
"testing"

go/ir/feature_column.go

Lines changed: 118 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,16 @@
1313

1414
package ir
1515

16-
import "fmt"
16+
import (
	"fmt"
	"sort"
	"strings"
)
1720

1821
// FeatureColumn corresponds to the COLUMN clause in TO TRAIN.
// Every concrete column type in this file implements it.
1922
type FeatureColumn interface {
2023
// GetFieldDesc returns the field descriptions associated with this column.
GetFieldDesc() []*FieldDesc
2124
// ApplyTo returns a FeatureColumn built for the given FieldDesc, or an
// error (e.g. NumericColumn.ApplyTo returns a new NumericColumn wrapping it).
ApplyTo(*FieldDesc) (FeatureColumn, error)
25+
// GenPythonCode returns Python source text; the implementations in this
// file emit a constructor call for a runtime.feature.column.* object.
GenPythonCode() string
2226
}
2327

2428
// CategoryColumn corresponds to categorical column
@@ -43,6 +47,27 @@ type FieldDesc struct {
4347
MaxID int64
4448
}
4549

50+
// GenPythonCode generate Python code to construct a runtime.feature.field_desc
51+
func (fd *FieldDesc) GenPythonCode() string {
52+
isSparseStr := "False"
53+
if fd.IsSparse {
54+
isSparseStr = "True"
55+
}
56+
vocabList := []string{}
57+
for k := range fd.Vocabulary {
58+
vocabList = append(vocabList, k)
59+
}
60+
// pass format = "" to let runtime feature derivation to fill it in.
61+
return fmt.Sprintf(`runtime.feature.field_desc.FieldDesc(name="%s", dtype=fd.DataType.%s, delimiter="%s", format="", shape=%s, is_sparse=%s, vocabulary=%s)`,
62+
fd.Name,
63+
strings.ToUpper(DTypeToString(fd.DType)),
64+
fd.Delimiter,
65+
AttrToPythonValue(fd.Shape),
66+
isSparseStr,
67+
AttrToPythonValue(vocabList),
68+
)
69+
}
70+
4671
// Possible DType values in FieldDesc
4772
const (
4873
Int int = iota
@@ -68,6 +93,12 @@ func (c *NumericColumn) ApplyTo(other *FieldDesc) (FeatureColumn, error) {
6893
return &NumericColumn{other}, nil
6994
}
7095

96+
// GenPythonCode generate Python code to construct a runtime.feature.column.*
97+
func (c *NumericColumn) GenPythonCode() string {
98+
code := fmt.Sprintf(`runtime.feature.column.NumericColumn(%s)`, c.FieldDesc.GenPythonCode())
99+
return code
100+
}
101+
71102
// BucketColumn represents `tf.feature_column.bucketized_column`
72103
// ref: https://www.tensorflow.org/api_docs/python/tf/feature_column/bucketized_column
73104
type BucketColumn struct {
@@ -97,6 +128,15 @@ func (c *BucketColumn) NumClass() int64 {
97128
return int64(len(c.Boundaries)) + 1
98129
}
99130

131+
// GenPythonCode generate Python code to construct a runtime.feature.column.*
132+
func (c *BucketColumn) GenPythonCode() string {
133+
code := fmt.Sprintf(`runtime.feature.column.BucketColumn(%s, %s)`,
134+
c.SourceColumn.GenPythonCode(),
135+
AttrToPythonValue(c.Boundaries),
136+
)
137+
return code
138+
}
139+
100140
// CrossColumn represents `tf.feature_column.crossed_column`
101141
// ref: https://www.tensorflow.org/api_docs/python/tf/feature_column/crossed_column
102142
type CrossColumn struct {
@@ -133,6 +173,23 @@ func (c *CrossColumn) NumClass() int64 {
133173
return c.HashBucketSize
134174
}
135175

176+
// GenPythonCode generate Python code to construct a runtime.feature.column.*
177+
func (c *CrossColumn) GenPythonCode() string {
178+
keysCode := []string{}
179+
for _, k := range c.Keys {
180+
if strKey, ok := k.(string); ok {
181+
keysCode = append(keysCode, strKey)
182+
} else if nc, ok := k.(*NumericColumn); ok {
183+
keysCode = append(keysCode, nc.GenPythonCode())
184+
}
185+
}
186+
code := fmt.Sprintf(`runtime.feature.column.CrossColumn([%s], %d)`,
187+
strings.Join(keysCode, ","),
188+
c.HashBucketSize,
189+
)
190+
return code
191+
}
192+
136193
// CategoryIDColumn represents `tf.feature_column.categorical_column_with_identity`
137194
// ref: https://www.tensorflow.org/api_docs/python/tf/feature_column/categorical_column_with_identity
138195
type CategoryIDColumn struct {
@@ -155,6 +212,15 @@ func (c *CategoryIDColumn) NumClass() int64 {
155212
return c.BucketSize
156213
}
157214

215+
// GenPythonCode generate Python code to construct a runtime.feature.column.*
216+
func (c *CategoryIDColumn) GenPythonCode() string {
217+
code := fmt.Sprintf(`runtime.feature.column.CategoryIDColumn(%s, %d)`,
218+
c.FieldDesc.GenPythonCode(),
219+
c.BucketSize,
220+
)
221+
return code
222+
}
223+
158224
// CategoryHashColumn represents `tf.feature_column.categorical_column_with_hash_bucket`
159225
// ref: https://www.tensorflow.org/api_docs/python/tf/feature_column/categorical_column_with_hash_bucket
160226
type CategoryHashColumn struct {
@@ -177,6 +243,15 @@ func (c *CategoryHashColumn) NumClass() int64 {
177243
return c.BucketSize
178244
}
179245

246+
// GenPythonCode generate Python code to construct a runtime.feature.column.*
247+
func (c *CategoryHashColumn) GenPythonCode() string {
248+
code := fmt.Sprintf(`runtime.feature.column.CategoryHashColumn(%s, %d)`,
249+
c.FieldDesc.GenPythonCode(),
250+
c.BucketSize,
251+
)
252+
return code
253+
}
254+
180255
// SeqCategoryIDColumn represents `tf.feature_column.sequence_categorical_column_with_identity`
181256
// ref: https://www.tensorflow.org/api_docs/python/tf/feature_column/sequence_categorical_column_with_identity
182257
type SeqCategoryIDColumn struct {
@@ -199,6 +274,15 @@ func (c *SeqCategoryIDColumn) NumClass() int64 {
199274
return c.BucketSize
200275
}
201276

277+
// GenPythonCode generate Python code to construct a runtime.feature.column.*
278+
func (c *SeqCategoryIDColumn) GenPythonCode() string {
279+
code := fmt.Sprintf(`runtime.feature.column.SeqCategoryIDColumn(%s, %d)`,
280+
c.FieldDesc.GenPythonCode(),
281+
c.BucketSize,
282+
)
283+
return code
284+
}
285+
202286
// EmbeddingColumn represents `tf.feature_column.embedding_column`
203287
// ref: https://www.tensorflow.org/api_docs/python/tf/feature_column/embedding_column
204288
type EmbeddingColumn struct {
@@ -243,6 +327,24 @@ func (c *EmbeddingColumn) ApplyTo(other *FieldDesc) (FeatureColumn, error) {
243327
return ret, nil
244328
}
245329

330+
// GenPythonCode generate Python code to construct a runtime.feature.column.*
331+
func (c *EmbeddingColumn) GenPythonCode() string {
332+
catColCode := ""
333+
if c.CategoryColumn == nil {
334+
catColCode = "None"
335+
} else {
336+
catColCode = c.CategoryColumn.GenPythonCode()
337+
}
338+
code := fmt.Sprintf(`runtime.feature.column.EmbeddingColumn(category_column=%s, dimension=%d, combiner="%s", initializer="%s", name="%s")`,
339+
catColCode,
340+
c.Dimension,
341+
c.Combiner,
342+
c.Initializer,
343+
c.Name,
344+
)
345+
return code
346+
}
347+
246348
// IndicatorColumn represents `tf.feature_column.indicator_column`
247349
// ref: https://www.tensorflow.org/api_docs/python/tf/feature_column/indicator_column
248350
type IndicatorColumn struct {
@@ -277,3 +379,18 @@ func (c *IndicatorColumn) ApplyTo(other *FieldDesc) (FeatureColumn, error) {
277379
}
278380
return ret, nil
279381
}
382+
383+
// GenPythonCode generate Python code to construct a runtime.feature.column.*
384+
func (c *IndicatorColumn) GenPythonCode() string {
385+
catColCode := ""
386+
if c.CategoryColumn == nil {
387+
catColCode = "None"
388+
} else {
389+
catColCode = c.CategoryColumn.GenPythonCode()
390+
}
391+
code := fmt.Sprintf(`runtime.feature.column.IndicatorColumn(category_column=%s, name="%s")`,
392+
catColCode,
393+
c.Name,
394+
)
395+
return code
396+
}

0 commit comments

Comments
 (0)