1313
1414package ir
1515
import (
	"fmt"
	"sort"
	"strings"
)
1720
1821// FeatureColumn corresponds to the COLUMN clause in TO TRAIN.
1922type FeatureColumn interface {
2023 GetFieldDesc () []* FieldDesc
2124 ApplyTo (* FieldDesc ) (FeatureColumn , error )
25+ GenPythonCode () string
2226}
2327
2428// CategoryColumn corresponds to categorical column
@@ -43,6 +47,27 @@ type FieldDesc struct {
4347 MaxID int64
4448}
4549
50+ // GenPythonCode generate Python code to construct a runtime.feature.field_desc
51+ func (fd * FieldDesc ) GenPythonCode () string {
52+ isSparseStr := "False"
53+ if fd .IsSparse {
54+ isSparseStr = "True"
55+ }
56+ vocabList := []string {}
57+ for k := range fd .Vocabulary {
58+ vocabList = append (vocabList , k )
59+ }
60+ // pass format = "" to let runtime feature derivation to fill it in.
61+ return fmt .Sprintf (`runtime.feature.field_desc.FieldDesc(name="%s", dtype=fd.DataType.%s, delimiter="%s", format="", shape=%s, is_sparse=%s, vocabulary=%s)` ,
62+ fd .Name ,
63+ strings .ToUpper (DTypeToString (fd .DType )),
64+ fd .Delimiter ,
65+ AttrToPythonValue (fd .Shape ),
66+ isSparseStr ,
67+ AttrToPythonValue (vocabList ),
68+ )
69+ }
70+
4671// Possible DType values in FieldDesc
4772const (
4873 Int int = iota
@@ -68,6 +93,12 @@ func (c *NumericColumn) ApplyTo(other *FieldDesc) (FeatureColumn, error) {
6893 return & NumericColumn {other }, nil
6994}
7095
96+ // GenPythonCode generate Python code to construct a runtime.feature.column.*
97+ func (c * NumericColumn ) GenPythonCode () string {
98+ code := fmt .Sprintf (`runtime.feature.column.NumericColumn(%s)` , c .FieldDesc .GenPythonCode ())
99+ return code
100+ }
101+
71102// BucketColumn represents `tf.feature_column.bucketized_column`
72103// ref: https://www.tensorflow.org/api_docs/python/tf/feature_column/bucketized_column
73104type BucketColumn struct {
@@ -97,6 +128,15 @@ func (c *BucketColumn) NumClass() int64 {
97128 return int64 (len (c .Boundaries )) + 1
98129}
99130
131+ // GenPythonCode generate Python code to construct a runtime.feature.column.*
132+ func (c * BucketColumn ) GenPythonCode () string {
133+ code := fmt .Sprintf (`runtime.feature.column.BucketColumn(%s, %s)` ,
134+ c .SourceColumn .GenPythonCode (),
135+ AttrToPythonValue (c .Boundaries ),
136+ )
137+ return code
138+ }
139+
100140// CrossColumn represents `tf.feature_column.crossed_column`
101141// ref: https://www.tensorflow.org/api_docs/python/tf/feature_column/crossed_column
102142type CrossColumn struct {
@@ -133,6 +173,23 @@ func (c *CrossColumn) NumClass() int64 {
133173 return c .HashBucketSize
134174}
135175
176+ // GenPythonCode generate Python code to construct a runtime.feature.column.*
177+ func (c * CrossColumn ) GenPythonCode () string {
178+ keysCode := []string {}
179+ for _ , k := range c .Keys {
180+ if strKey , ok := k .(string ); ok {
181+ keysCode = append (keysCode , strKey )
182+ } else if nc , ok := k .(* NumericColumn ); ok {
183+ keysCode = append (keysCode , nc .GenPythonCode ())
184+ }
185+ }
186+ code := fmt .Sprintf (`runtime.feature.column.CrossColumn([%s], %d)` ,
187+ strings .Join (keysCode , "," ),
188+ c .HashBucketSize ,
189+ )
190+ return code
191+ }
192+
136193// CategoryIDColumn represents `tf.feature_column.categorical_column_with_identity`
137194// ref: https://www.tensorflow.org/api_docs/python/tf/feature_column/categorical_column_with_identity
138195type CategoryIDColumn struct {
@@ -155,6 +212,15 @@ func (c *CategoryIDColumn) NumClass() int64 {
155212 return c .BucketSize
156213}
157214
215+ // GenPythonCode generate Python code to construct a runtime.feature.column.*
216+ func (c * CategoryIDColumn ) GenPythonCode () string {
217+ code := fmt .Sprintf (`runtime.feature.column.CategoryIDColumn(%s, %d)` ,
218+ c .FieldDesc .GenPythonCode (),
219+ c .BucketSize ,
220+ )
221+ return code
222+ }
223+
158224// CategoryHashColumn represents `tf.feature_column.categorical_column_with_hash_bucket`
159225// ref: https://www.tensorflow.org/api_docs/python/tf/feature_column/categorical_column_with_hash_bucket
160226type CategoryHashColumn struct {
@@ -177,6 +243,15 @@ func (c *CategoryHashColumn) NumClass() int64 {
177243 return c .BucketSize
178244}
179245
246+ // GenPythonCode generate Python code to construct a runtime.feature.column.*
247+ func (c * CategoryHashColumn ) GenPythonCode () string {
248+ code := fmt .Sprintf (`runtime.feature.column.CategoryHashColumn(%s, %d)` ,
249+ c .FieldDesc .GenPythonCode (),
250+ c .BucketSize ,
251+ )
252+ return code
253+ }
254+
180255// SeqCategoryIDColumn represents `tf.feature_column.sequence_categorical_column_with_identity`
181256// ref: https://www.tensorflow.org/api_docs/python/tf/feature_column/sequence_categorical_column_with_identity
182257type SeqCategoryIDColumn struct {
@@ -199,6 +274,15 @@ func (c *SeqCategoryIDColumn) NumClass() int64 {
199274 return c .BucketSize
200275}
201276
277+ // GenPythonCode generate Python code to construct a runtime.feature.column.*
278+ func (c * SeqCategoryIDColumn ) GenPythonCode () string {
279+ code := fmt .Sprintf (`runtime.feature.column.SeqCategoryIDColumn(%s, %d)` ,
280+ c .FieldDesc .GenPythonCode (),
281+ c .BucketSize ,
282+ )
283+ return code
284+ }
285+
202286// EmbeddingColumn represents `tf.feature_column.embedding_column`
203287// ref: https://www.tensorflow.org/api_docs/python/tf/feature_column/embedding_column
204288type EmbeddingColumn struct {
@@ -243,6 +327,24 @@ func (c *EmbeddingColumn) ApplyTo(other *FieldDesc) (FeatureColumn, error) {
243327 return ret , nil
244328}
245329
330+ // GenPythonCode generate Python code to construct a runtime.feature.column.*
331+ func (c * EmbeddingColumn ) GenPythonCode () string {
332+ catColCode := ""
333+ if c .CategoryColumn == nil {
334+ catColCode = "None"
335+ } else {
336+ catColCode = c .CategoryColumn .GenPythonCode ()
337+ }
338+ code := fmt .Sprintf (`runtime.feature.column.EmbeddingColumn(category_column=%s, dimension=%d, combiner="%s", initializer="%s", name="%s")` ,
339+ catColCode ,
340+ c .Dimension ,
341+ c .Combiner ,
342+ c .Initializer ,
343+ c .Name ,
344+ )
345+ return code
346+ }
347+
246348// IndicatorColumn represents `tf.feature_column.indicator_column`
247349// ref: https://www.tensorflow.org/api_docs/python/tf/feature_column/indicator_column
248350type IndicatorColumn struct {
@@ -277,3 +379,18 @@ func (c *IndicatorColumn) ApplyTo(other *FieldDesc) (FeatureColumn, error) {
277379 }
278380 return ret , nil
279381}
382+
383+ // GenPythonCode generate Python code to construct a runtime.feature.column.*
384+ func (c * IndicatorColumn ) GenPythonCode () string {
385+ catColCode := ""
386+ if c .CategoryColumn == nil {
387+ catColCode = "None"
388+ } else {
389+ catColCode = c .CategoryColumn .GenPythonCode ()
390+ }
391+ code := fmt .Sprintf (`runtime.feature.column.IndicatorColumn(category_column=%s, name="%s")` ,
392+ catColCode ,
393+ c .Name ,
394+ )
395+ return code
396+ }
0 commit comments