Skip to content

Commit 79301b6

Browse files
authored
Make raw column kinds optional (#111)
1 parent 0c98c22 commit 79301b6

23 files changed

+710
-142
lines changed

cli/cmd/get.go

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -306,7 +306,7 @@ func describeRawColumn(name string, resourcesRes *schema.GetResourcesResponse) (
306306
}
307307
dataStatus := resourcesRes.DataStatuses[rawColumn.GetID()]
308308
out := dataStatusSummary(dataStatus)
309-
out += resourceStr(rawColumn.GetUserConfig())
309+
out += resourceStr(context.GetRawColumnUserConfig(rawColumn))
310310
return out, nil
311311
}
312312

examples/reviews/resources/tokenized_columns.yaml renamed to examples/reviews/resources/columns.yaml

Lines changed: 10 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1,3 +1,13 @@
1+
- kind: environment
2+
name: dev
3+
data:
4+
type: csv
5+
path: s3a://cortex-examples/reviews.csv
6+
csv_config:
7+
header: true
8+
escape: "\""
9+
schema: ["review", "label"]
10+
111
- kind: transformed_column
212
name: embedding_input
313
transformer_path: implementations/transformers/tokenize_string_to_int.py

examples/reviews/resources/raw_columns.yaml

Lines changed: 0 additions & 17 deletions
This file was deleted.

examples/reviews/resources/vocab.yaml

Lines changed: 1 addition & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -1,15 +1,6 @@
1-
- kind: aggregator
2-
name: vocab
3-
output_type: {STRING: INT}
4-
inputs:
5-
columns:
6-
col: STRING_COLUMN
7-
args:
8-
vocab_size: INT
9-
101
- kind: aggregate
112
name: reviews_vocab
12-
aggregator: vocab
3+
aggregator_path: implementations/aggregators/vocab.py
134
inputs:
145
columns:
156
col: review

pkg/operator/api/context/raw_columns.go

Lines changed: 24 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -27,7 +27,6 @@ type RawColumns map[string]RawColumn
2727
type RawColumn interface {
2828
Column
2929
GetCompute() *userconfig.SparkCompute
30-
GetUserConfig() userconfig.Resource
3130
}
3231

3332
type RawIntColumn struct {
@@ -45,6 +44,11 @@ type RawStringColumn struct {
4544
*ComputedResourceFields
4645
}
4746

47+
type RawInferredColumn struct {
48+
*userconfig.RawInferredColumn
49+
*ComputedResourceFields
50+
}
51+
4852
func (rawColumns RawColumns) OneByID(id string) RawColumn {
4953
for _, rawColumn := range rawColumns {
5054
if rawColumn.GetID() == id {
@@ -79,6 +83,21 @@ func (rawColumns RawColumns) columnInputsID(columnInputValues map[string]interfa
7983
return hash.Any(columnIDMap)
8084
}
8185

86+
func GetRawColumnUserConfig(rawColumn RawColumn) userconfig.Resource {
87+
switch rawColumn.GetType() {
88+
case userconfig.IntegerColumnType:
89+
return rawColumn.(*RawIntColumn).RawIntColumn
90+
case userconfig.FloatColumnType:
91+
return rawColumn.(*RawFloatColumn).RawFloatColumn
92+
case userconfig.StringColumnType:
93+
return rawColumn.(*RawStringColumn).RawStringColumn
94+
case userconfig.InferredColumnType:
95+
return rawColumn.(*RawInferredColumn).RawInferredColumn
96+
}
97+
98+
return nil
99+
}
100+
82101
func (rawColumns RawColumns) ColumnInputsID(columnInputValues map[string]interface{}) string {
83102
return rawColumns.columnInputsID(columnInputValues, false)
84103
}
@@ -98,3 +117,7 @@ func (rawColumn *RawFloatColumn) GetInputRawColumnNames() []string {
98117
func (rawColumn *RawStringColumn) GetInputRawColumnNames() []string {
99118
return []string{rawColumn.GetName()}
100119
}
120+
121+
func (rawColumn *RawInferredColumn) GetInputRawColumnNames() []string {
122+
return []string{rawColumn.GetName()}
123+
}

pkg/operator/api/context/serialize.go

Lines changed: 14 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -25,9 +25,10 @@ import (
2525
)
2626

2727
type RawColumnsTypeSplit struct {
28-
RawIntColumns map[string]*RawIntColumn `json:"raw_int_columns"`
29-
RawStringColumns map[string]*RawStringColumn `json:"raw_string_columns"`
30-
RawFloatColumns map[string]*RawFloatColumn `json:"raw_float_columns"`
28+
RawIntColumns map[string]*RawIntColumn `json:"raw_int_columns"`
29+
RawStringColumns map[string]*RawStringColumn `json:"raw_string_columns"`
30+
RawFloatColumns map[string]*RawFloatColumn `json:"raw_float_columns"`
31+
RawInferredColumns map[string]*RawInferredColumn `json:"raw_inferred_columns"`
3132
}
3233

3334
type DataSplit struct {
@@ -45,6 +46,7 @@ func (ctx Context) splitRawColumns() *RawColumnsTypeSplit {
4546
var rawIntColumns = make(map[string]*RawIntColumn)
4647
var rawFloatColumns = make(map[string]*RawFloatColumn)
4748
var rawStringColumns = make(map[string]*RawStringColumn)
49+
var rawInferredColumns = make(map[string]*RawInferredColumn)
4850
for name, rawColumn := range ctx.RawColumns {
4951
switch typedRawColumn := rawColumn.(type) {
5052
case *RawIntColumn:
@@ -53,13 +55,16 @@ func (ctx Context) splitRawColumns() *RawColumnsTypeSplit {
5355
rawFloatColumns[name] = typedRawColumn
5456
case *RawStringColumn:
5557
rawStringColumns[name] = typedRawColumn
58+
case *RawInferredColumn:
59+
rawInferredColumns[name] = typedRawColumn
5660
}
5761
}
5862

5963
return &RawColumnsTypeSplit{
60-
RawIntColumns: rawIntColumns,
61-
RawFloatColumns: rawFloatColumns,
62-
RawStringColumns: rawStringColumns,
64+
RawIntColumns: rawIntColumns,
65+
RawFloatColumns: rawFloatColumns,
66+
RawStringColumns: rawStringColumns,
67+
RawInferredColumns: rawInferredColumns,
6368
}
6469
}
6570

@@ -75,6 +80,9 @@ func (serial Serial) collectRawColumns() RawColumns {
7580
for name, rawColumn := range serial.RawColumnSplit.RawStringColumns {
7681
rawColumns[name] = rawColumn
7782
}
83+
for name, rawColumn := range serial.RawColumnSplit.RawInferredColumns {
84+
rawColumns[name] = rawColumn
85+
}
7886

7987
return rawColumns
8088
}

pkg/operator/api/userconfig/column_type.go

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -31,6 +31,7 @@ const (
3131
IntegerListColumnType
3232
FloatListColumnType
3333
StringListColumnType
34+
InferredColumnType
3435
)
3536

3637
var columnTypes = []string{
@@ -41,6 +42,7 @@ var columnTypes = []string{
4142
"INT_LIST_COLUMN",
4243
"FLOAT_LIST_COLUMN",
4344
"STRING_LIST_COLUMN",
45+
"INFERRED_COLUMN",
4446
}
4547

4648
var columnJSONPlaceholders = []string{
@@ -51,6 +53,7 @@ var columnJSONPlaceholders = []string{
5153
"[INT]",
5254
"[FLOAT]",
5355
"[\"STRING\"]",
56+
"VALUE",
5457
}
5558

5659
func ColumnTypeFromString(s string) ColumnType {

pkg/operator/api/userconfig/config.go

Lines changed: 19 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -154,9 +154,9 @@ func (config *Config) Validate(envName string) error {
154154
rawColumnNames := config.RawColumns.Names()
155155
for _, env := range config.Environments {
156156
ingestedColumnNames := env.Data.GetIngestedColumns()
157-
missingColumns := slices.SubtractStrSlice(rawColumnNames, ingestedColumnNames)
158-
if len(missingColumns) > 0 {
159-
return errors.Wrap(ErrorRawColumnNotInEnv(env.Name), Identify(config.RawColumns.Get(missingColumns[0])))
157+
missingColumnNames := slices.SubtractStrSlice(rawColumnNames, ingestedColumnNames)
158+
if len(missingColumnNames) > 0 {
159+
return errors.Wrap(ErrorRawColumnNotInEnv(env.Name), Identify(config.RawColumns.Get(missingColumnNames[0])))
160160
}
161161
extraColumns := slices.SubtractStrSlice(rawColumnNames, ingestedColumnNames)
162162
if len(extraColumns) > 0 {
@@ -440,6 +440,22 @@ func New(configs map[string][]byte, envName string) (*Config, error) {
440440
}
441441
}
442442

443+
for _, env := range config.Environments {
444+
ingestedColumnNames := env.Data.GetIngestedColumns()
445+
missingColumnNames := slices.SubtractStrSlice(ingestedColumnNames, config.RawColumns.Names())
446+
for _, inferredColumnName := range missingColumnNames {
447+
inferredRawColumn := &RawInferredColumn{
448+
ResourceFields: ResourceFields{
449+
Name: inferredColumnName,
450+
},
451+
Type: InferredColumnType,
452+
Compute: &SparkCompute{},
453+
}
454+
cr.Struct(inferredRawColumn.Compute, make(map[string]interface{}), sparkComputeStructValidation)
455+
config.RawColumns = append(config.RawColumns, inferredRawColumn)
456+
}
457+
}
458+
443459
if err := config.Validate(envName); err != nil {
444460
return nil, err
445461
}

pkg/operator/api/userconfig/environments.go

Lines changed: 8 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -21,6 +21,7 @@ import (
2121
cr "github.com/cortexlabs/cortex/pkg/lib/configreader"
2222
"github.com/cortexlabs/cortex/pkg/lib/errors"
2323
"github.com/cortexlabs/cortex/pkg/lib/pointer"
24+
"github.com/cortexlabs/cortex/pkg/lib/sets/strset"
2425
"github.com/cortexlabs/cortex/pkg/lib/slices"
2526
"github.com/cortexlabs/cortex/pkg/operator/api/resource"
2627
)
@@ -337,6 +338,13 @@ func (environments Environments) Validate() error {
337338
return ErrorDuplicateResourceName(dups...)
338339
}
339340

341+
ingestedColumns := environments[0].Data.GetIngestedColumns()
342+
for _, env := range environments[1:] {
343+
if !strset.New(ingestedColumns...).IsEqual(strset.New(env.Data.GetIngestedColumns()...)) {
344+
return ErrorEnvSchemaMismatch(environments[0], env)
345+
}
346+
}
347+
340348
return nil
341349
}
342350

pkg/operator/api/userconfig/errors.go

Lines changed: 15 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -59,6 +59,7 @@ const (
5959
ErrRegressionTargetType
6060
ErrClassificationTargetType
6161
ErrSpecifyOnlyOneMissing
62+
ErrEnvSchemaMismatch
6263
)
6364

6465
var errorKinds = []string{
@@ -92,9 +93,10 @@ var errorKinds = []string{
9293
"err_regression_target_type",
9394
"err_classification_target_type",
9495
"err_specify_only_one_missing",
96+
"err_env_schema_mismatch",
9597
}
9698

97-
var _ = [1]int{}[int(ErrSpecifyOnlyOneMissing)-(len(errorKinds)-1)] // Ensure list length matches
99+
var _ = [1]int{}[int(ErrEnvSchemaMismatch)-(len(errorKinds)-1)] // Ensure list length matches
98100

99101
func (t ErrorKind) String() string {
100102
return errorKinds[t]
@@ -397,3 +399,15 @@ func ErrorSpecifyOnlyOneMissing(vals ...string) error {
397399
message: message,
398400
}
399401
}
402+
403+
func ErrorEnvSchemaMismatch(env1, env2 *Environment) error {
404+
return Error{
405+
Kind: ErrEnvSchemaMismatch,
406+
message: fmt.Sprintf("schemas diverge between environments (%s lists %s, and %s lists %s)",
407+
env1.Name,
408+
s.StrsAnd(env1.Data.GetIngestedColumns()),
409+
env2.Name,
410+
s.StrsAnd(env2.Data.GetIngestedColumns()),
411+
),
412+
}
413+
}

pkg/operator/api/userconfig/raw_columns.go

Lines changed: 20 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -25,7 +25,6 @@ type RawColumn interface {
2525
Column
2626
GetType() ColumnType
2727
GetCompute() *SparkCompute
28-
GetUserConfig() Resource
2928
}
3029

3130
type RawColumns []RawColumn
@@ -181,6 +180,12 @@ var rawStringColumnFieldValidations = []*cr.StructFieldValidation{
181180
typeFieldValidation,
182181
}
183182

183+
type RawInferredColumn struct {
184+
ResourceFields
185+
Type ColumnType `json:"type" yaml:"type"`
186+
Compute *SparkCompute `json:"compute" yaml:"compute"`
187+
}
188+
184189
func (rawColumns RawColumns) Validate() error {
185190
resources := make([]Resource, len(rawColumns))
186191
for i, res := range rawColumns {
@@ -224,6 +229,10 @@ func (column *RawStringColumn) GetType() ColumnType {
224229
return column.Type
225230
}
226231

232+
func (column *RawInferredColumn) GetType() ColumnType {
233+
return column.Type
234+
}
235+
227236
func (column *RawIntColumn) GetCompute() *SparkCompute {
228237
return column.Compute
229238
}
@@ -236,6 +245,10 @@ func (column *RawStringColumn) GetCompute() *SparkCompute {
236245
return column.Compute
237246
}
238247

248+
func (column *RawInferredColumn) GetCompute() *SparkCompute {
249+
return column.Compute
250+
}
251+
239252
func (column *RawIntColumn) GetResourceType() resource.Type {
240253
return resource.RawColumnType
241254
}
@@ -248,6 +261,10 @@ func (column *RawStringColumn) GetResourceType() resource.Type {
248261
return resource.RawColumnType
249262
}
250263

264+
func (column *RawInferredColumn) GetResourceType() resource.Type {
265+
return resource.RawColumnType
266+
}
267+
251268
func (column *RawIntColumn) IsRaw() bool {
252269
return true
253270
}
@@ -260,14 +277,6 @@ func (column *RawStringColumn) IsRaw() bool {
260277
return true
261278
}
262279

263-
func (column *RawIntColumn) GetUserConfig() Resource {
264-
return column
265-
}
266-
267-
func (column *RawFloatColumn) GetUserConfig() Resource {
268-
return column
269-
}
270-
271-
func (column *RawStringColumn) GetUserConfig() Resource {
272-
return column
280+
func (column *RawInferredColumn) IsRaw() bool {
281+
return true
273282
}

pkg/operator/api/userconfig/validators.go

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -139,6 +139,11 @@ func CheckColumnRuntimeTypesMatch(columnRuntimeTypes map[string]interface{}, col
139139
if !ok {
140140
return errors.Wrap(ErrorUnsupportedColumnType(columnRuntimeTypeInter, validTypes), columnInputName)
141141
}
142+
143+
if columnRuntimeType == InferredColumnType {
144+
continue
145+
}
146+
142147
if !slices.HasString(validTypes, columnRuntimeType.String()) {
143148
return errors.Wrap(ErrorUnsupportedColumnType(columnRuntimeType, validTypes), columnInputName)
144149
}

pkg/operator/context/raw_columns.go

Lines changed: 13 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -93,8 +93,20 @@ func getRawColumns(
9393
},
9494
RawStringColumn: typedColumnConfig,
9595
}
96+
case *userconfig.RawInferredColumn:
97+
buf.WriteString(typedColumnConfig.Name)
98+
id := hash.Bytes(buf.Bytes())
99+
rawColumn = &context.RawInferredColumn{
100+
ComputedResourceFields: &context.ComputedResourceFields{
101+
ResourceFields: &context.ResourceFields{
102+
ID: id,
103+
ResourceType: resource.RawColumnType,
104+
},
105+
},
106+
RawInferredColumn: typedColumnConfig,
107+
}
96108
default:
97-
return nil, errors.Wrap(configreader.ErrorInvalidStr(userconfig.TypeKey, userconfig.IntegerColumnType.String(), userconfig.FloatColumnType.String(), userconfig.StringColumnType.String()), userconfig.Identify(columnConfig)) // unexpected error
109+
return nil, errors.Wrap(configreader.ErrorInvalidStr(typedColumnConfig.GetType().String(), userconfig.IntegerColumnType.String(), userconfig.FloatColumnType.String(), userconfig.StringColumnType.String()), userconfig.Identify(columnConfig)) // unexpected error
98110
}
99111

100112
rawColumns[columnConfig.GetName()] = rawColumn

pkg/operator/context/transformers.go

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -67,7 +67,8 @@ func loadUserTransformers(
6767
ResourceFields: userconfig.ResourceFields{
6868
Name: implHash,
6969
},
70-
Path: *transColConfig.TransformerPath,
70+
OutputType: userconfig.InferredColumnType,
71+
Path: *transColConfig.TransformerPath,
7172
}
7273
transformer, err := newTransformer(*anonTransformerConfig, impl, nil, pythonPackages)
7374
if err != nil {

0 commit comments

Comments
 (0)