Skip to content

Commit be0f18d

Browse files
committed
Fix ContextFromSerial()
1 parent 7c3a7c1 commit be0f18d

File tree

6 files changed

+257
-166
lines changed

6 files changed

+257
-166
lines changed

pkg/lib/msgpack/msgpack.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,9 @@ limitations under the License.
1717
package msgpack
1818

1919
import (
20-
"github.com/cortexlabs/cortex/pkg/lib/errors"
2120
"github.com/ugorji/go/codec"
21+
22+
"github.com/cortexlabs/cortex/pkg/lib/errors"
2223
)
2324

2425
var mh codec.MsgpackHandle

pkg/operator/api/context/aggregates.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,8 @@ type Aggregates map[string]*Aggregate
2525
type Aggregate struct {
2626
*userconfig.Aggregate
2727
*ComputedResourceFields
28-
Type interface{} `json:"type"`
29-
Key string `json:"key"`
28+
Type userconfig.OutputSchema `json:"type"`
29+
Key string `json:"key"`
3030
}
3131

3232
func (aggregates Aggregates) OneByID(id string) *Aggregate {

pkg/operator/api/context/serialize.go

Lines changed: 93 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -37,12 +37,12 @@ type DataSplit struct {
3737
}
3838

3939
type Serial struct {
40-
Context
40+
*Context
4141
RawColumnSplit *RawColumnsTypeSplit `json:"raw_columns"`
4242
DataSplit *DataSplit `json:"environment_data"`
4343
}
4444

45-
func (ctx Context) splitRawColumns() *RawColumnsTypeSplit {
45+
func (ctx *Context) splitRawColumns() *RawColumnsTypeSplit {
4646
var rawIntColumns = make(map[string]*RawIntColumn)
4747
var rawFloatColumns = make(map[string]*RawFloatColumn)
4848
var rawStringColumns = make(map[string]*RawStringColumn)
@@ -68,7 +68,7 @@ func (ctx Context) splitRawColumns() *RawColumnsTypeSplit {
6868
}
6969
}
7070

71-
func (serial Serial) collectRawColumns() RawColumns {
71+
func (serial *Serial) collectRawColumns() RawColumns {
7272
var rawColumns = make(map[string]RawColumn)
7373

7474
for name, rawColumn := range serial.RawColumnSplit.RawIntColumns {
@@ -87,7 +87,7 @@ func (serial Serial) collectRawColumns() RawColumns {
8787
return rawColumns
8888
}
8989

90-
func (ctx Context) splitEnvironment() *DataSplit {
90+
func (ctx *Context) splitEnvironment() *DataSplit {
9191
var split DataSplit
9292
switch typedData := ctx.Environment.Data.(type) {
9393
case *userconfig.CSVData:
@@ -110,7 +110,85 @@ func (serial *Serial) collectEnvironment() (*Environment, error) {
110110
return serial.Environment, nil
111111
}
112112

113-
func (ctx Context) ToSerial() *Serial {
113+
func (ctx *Context) castSchemaTypes() error {
114+
for _, constant := range ctx.Constants {
115+
if constant.Type != nil {
116+
castedType, err := userconfig.ValidateOutputSchema(constant.Type)
117+
if err != nil {
118+
return err
119+
}
120+
constant.Constant.Type = castedType
121+
}
122+
}
123+
124+
for _, aggregator := range ctx.Aggregators {
125+
if aggregator.OutputType != nil {
126+
casted, err := userconfig.ValidateOutputSchema(aggregator.OutputType)
127+
if err != nil {
128+
return err
129+
}
130+
aggregator.Aggregator.OutputType = casted
131+
}
132+
133+
if aggregator.Input != nil {
134+
casted, err := userconfig.ValidateInputTypeSchema(aggregator.Input.Type, false, true)
135+
if err != nil {
136+
return err
137+
}
138+
aggregator.Aggregator.Input.Type = casted
139+
}
140+
}
141+
142+
for _, aggregate := range ctx.Aggregates {
143+
if aggregate.Type != nil {
144+
casted, err := userconfig.ValidateOutputSchema(aggregate.Type)
145+
if err != nil {
146+
return err
147+
}
148+
aggregate.Type = casted
149+
}
150+
}
151+
152+
for _, transformer := range ctx.Transformers {
153+
if transformer.Input != nil {
154+
casted, err := userconfig.ValidateInputTypeSchema(transformer.Input.Type, false, true)
155+
if err != nil {
156+
return err
157+
}
158+
transformer.Transformer.Input.Type = casted
159+
}
160+
}
161+
162+
for _, estimator := range ctx.Estimators {
163+
if estimator.Input != nil {
164+
casted, err := userconfig.ValidateInputTypeSchema(estimator.Input.Type, false, true)
165+
if err != nil {
166+
return err
167+
}
168+
estimator.Estimator.Input.Type = casted
169+
}
170+
171+
if estimator.TrainingInput != nil {
172+
casted, err := userconfig.ValidateInputTypeSchema(estimator.TrainingInput.Type, false, true)
173+
if err != nil {
174+
return err
175+
}
176+
estimator.Estimator.TrainingInput.Type = casted
177+
}
178+
179+
if estimator.Hparams != nil {
180+
casted, err := userconfig.ValidateInputTypeSchema(estimator.Hparams.Type, true, true)
181+
if err != nil {
182+
return err
183+
}
184+
estimator.Estimator.Hparams.Type = casted
185+
}
186+
}
187+
188+
return nil
189+
}
190+
191+
func (ctx *Context) ToSerial() *Serial {
114192
serial := Serial{
115193
Context: ctx,
116194
RawColumnSplit: ctx.splitRawColumns(),
@@ -120,15 +198,23 @@ func (ctx Context) ToSerial() *Serial {
120198
return &serial
121199
}
122200

123-
func (serial Serial) ContextFromSerial() (*Context, error) {
201+
func (serial *Serial) ContextFromSerial() (*Context, error) {
124202
ctx := serial.Context
203+
125204
ctx.RawColumns = serial.collectRawColumns()
205+
126206
environment, err := serial.collectEnvironment()
127207
if err != nil {
128208
return nil, err
129209
}
130210
ctx.Environment = environment
131-
return &ctx, nil
211+
212+
err = ctx.castSchemaTypes()
213+
if err != nil {
214+
return nil, err
215+
}
216+
217+
return ctx, nil
132218
}
133219

134220
func (ctx Context) ToMsgpackBytes() ([]byte, error) {

pkg/operator/api/userconfig/validators.go

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -41,14 +41,14 @@ type InputTypeSchema interface{} // CompundType, length-one array of *InputSchem
4141
type OutputSchema interface{} // ValueType, length-one array of OutputSchema, or map of {scalar|ValueType -> OutputSchema} (no *_COLUMN types, compound types, or input options like _default)
4242

4343
func inputSchemaValidator(in interface{}) (interface{}, error) {
44-
return ValidateInputSchema(in, false) // This casts it to *InputSchema
44+
return ValidateInputSchema(in, false, false) // This casts it to *InputSchema
4545
}
4646

4747
func inputSchemaValidatorValueTypesOnly(in interface{}) (interface{}, error) {
48-
return ValidateInputSchema(in, true) // This casts it to *InputSchema
48+
return ValidateInputSchema(in, true, false) // This casts it to *InputSchema
4949
}
5050

51-
func ValidateInputSchema(in interface{}, disallowColumnTypes bool) (*InputSchema, error) {
51+
func ValidateInputSchema(in interface{}, disallowColumnTypes bool, isAlreadyParsed bool) (*InputSchema, error) {
5252
// Check for cortex options vs short form
5353
if inMap, ok := cast.InterfaceToStrInterfaceMap(in); ok {
5454
foundUnderscore, foundNonUnderscore := false, false
@@ -72,7 +72,7 @@ func ValidateInputSchema(in interface{}, disallowColumnTypes bool) (*InputSchema
7272
InterfaceValidation: &cr.InterfaceValidation{
7373
Required: true,
7474
Validator: func(t interface{}) (interface{}, error) {
75-
return validateInputTypeSchema(t, disallowColumnTypes)
75+
return ValidateInputTypeSchema(t, disallowColumnTypes, isAlreadyParsed)
7676
},
7777
},
7878
},
@@ -81,8 +81,10 @@ func ValidateInputSchema(in interface{}, disallowColumnTypes bool) (*InputSchema
8181
BoolValidation: &cr.BoolValidation{},
8282
},
8383
{
84-
StructField: "Default",
85-
InterfaceValidation: &cr.InterfaceValidation{},
84+
StructField: "Default",
85+
InterfaceValidation: &cr.InterfaceValidation{
86+
AllowExplicitNull: isAlreadyParsed,
87+
},
8688
},
8789
{
8890
StructField: "AllowNull",
@@ -92,12 +94,14 @@ func ValidateInputSchema(in interface{}, disallowColumnTypes bool) (*InputSchema
9294
StructField: "MinCount",
9395
Int64PtrValidation: &cr.Int64PtrValidation{
9496
GreaterThanOrEqualTo: pointer.Int64(0),
97+
AllowExplicitNull: isAlreadyParsed,
9598
},
9699
},
97100
{
98101
StructField: "MaxCount",
99102
Int64PtrValidation: &cr.Int64PtrValidation{
100103
GreaterThanOrEqualTo: pointer.Int64(0),
104+
AllowExplicitNull: isAlreadyParsed,
101105
},
102106
},
103107
},
@@ -117,7 +121,7 @@ func ValidateInputSchema(in interface{}, disallowColumnTypes bool) (*InputSchema
117121
}
118122
}
119123

120-
typeSchema, err := validateInputTypeSchema(in, disallowColumnTypes)
124+
typeSchema, err := ValidateInputTypeSchema(in, disallowColumnTypes, isAlreadyParsed)
121125
if err != nil {
122126
return nil, err
123127
}
@@ -132,7 +136,7 @@ func ValidateInputSchema(in interface{}, disallowColumnTypes bool) (*InputSchema
132136
return inputSchema, nil
133137
}
134138

135-
func validateInputTypeSchema(in interface{}, disallowColumnTypes bool) (InputTypeSchema, error) {
139+
func ValidateInputTypeSchema(in interface{}, disallowColumnTypes bool, isAlreadyParsed bool) (InputTypeSchema, error) {
136140
// String
137141
if inStr, ok := in.(string); ok {
138142
compoundType, err := CompoundTypeFromString(inStr)
@@ -150,7 +154,7 @@ func validateInputTypeSchema(in interface{}, disallowColumnTypes bool) (InputTyp
150154
if len(inSlice) != 1 {
151155
return nil, ErrorTypeListLength(inSlice)
152156
}
153-
inputSchema, err := ValidateInputSchema(inSlice[0], disallowColumnTypes)
157+
inputSchema, err := ValidateInputSchema(inSlice[0], disallowColumnTypes, isAlreadyParsed)
154158
if err != nil {
155159
return nil, errors.Wrap(err, s.Index(0))
156160
}
@@ -182,7 +186,7 @@ func validateInputTypeSchema(in interface{}, disallowColumnTypes bool) (InputTyp
182186
if disallowColumnTypes && typeKey.IsColumns() {
183187
return nil, ErrorColumnTypeNotAllowed(typeKey)
184188
}
185-
valueInputSchema, err := ValidateInputSchema(typeValue, disallowColumnTypes)
189+
valueInputSchema, err := ValidateInputSchema(typeValue, disallowColumnTypes, isAlreadyParsed)
186190
if err != nil {
187191
return nil, errors.Wrap(err, string(typeKey))
188192
}
@@ -201,7 +205,7 @@ func validateInputTypeSchema(in interface{}, disallowColumnTypes bool) (InputTyp
201205
}
202206
}
203207

204-
valueInputSchema, err := ValidateInputSchema(value, disallowColumnTypes)
208+
valueInputSchema, err := ValidateInputSchema(value, disallowColumnTypes, isAlreadyParsed)
205209
if err != nil {
206210
return nil, errors.Wrap(err, s.UserStrStripped(key))
207211
}

0 commit comments

Comments
 (0)