Skip to content

Make raw column kinds optional #111

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 29 commits into from
Jun 3, 2019
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cli/cmd/get.go
Original file line number Diff line number Diff line change
Expand Up @@ -306,7 +306,7 @@ func describeRawColumn(name string, resourcesRes *schema.GetResourcesResponse) (
}
dataStatus := resourcesRes.DataStatuses[rawColumn.GetID()]
out := dataStatusSummary(dataStatus)
out += resourceStr(rawColumn.GetUserConfig())
out += resourceStr(rawColumn)
return out, nil
}

Expand Down
17 changes: 0 additions & 17 deletions examples/reviews/resources/raw_columns.yaml

This file was deleted.

10 changes: 10 additions & 0 deletions examples/reviews/resources/tokenized_columns.yaml
Original file line number Diff line number Diff line change
@@ -1,3 +1,13 @@
- kind: environment
name: dev
data:
type: csv
path: s3a://cortex-examples/reviews.csv
csv_config:
header: true
escape: "\""
schema: ["review", "label"]

- kind: transformed_column
name: embedding_input
transformer_path: implementations/transformers/tokenize_string_to_int.py
Expand Down
11 changes: 1 addition & 10 deletions examples/reviews/resources/vocab.yaml
Original file line number Diff line number Diff line change
@@ -1,15 +1,6 @@
- kind: aggregator
name: vocab
output_type: {STRING: INT}
inputs:
columns:
col: STRING_COLUMN
args:
vocab_size: INT

- kind: aggregate
name: reviews_vocab
aggregator: vocab
aggregator_path: implementations/aggregators/vocab.py
inputs:
columns:
col: review
Expand Down
10 changes: 9 additions & 1 deletion pkg/operator/api/context/raw_columns.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ type RawColumns map[string]RawColumn
type RawColumn interface {
Column
GetCompute() *userconfig.SparkCompute
GetUserConfig() userconfig.Resource
}

type RawIntColumn struct {
Expand All @@ -45,6 +44,11 @@ type RawStringColumn struct {
*ComputedResourceFields
}

// RawInferredColumn is the context wrapper for a raw column whose type was
// not declared by the user and must be inferred from the ingested data.
// It embeds the user-config definition plus the computed resource fields,
// mirroring the shape of RawIntColumn / RawFloatColumn / RawStringColumn above.
type RawInferredColumn struct {
	*userconfig.RawInferredColumn
	*ComputedResourceFields
}

func (rawColumns RawColumns) OneByID(id string) RawColumn {
for _, rawColumn := range rawColumns {
if rawColumn.GetID() == id {
Expand Down Expand Up @@ -98,3 +102,7 @@ func (rawColumn *RawFloatColumn) GetInputRawColumnNames() []string {
// GetInputRawColumnNames returns the raw columns this column depends on;
// a raw column depends only on itself, so this is a one-element slice
// containing the column's own name.
func (rawColumn *RawStringColumn) GetInputRawColumnNames() []string {
	return []string{rawColumn.GetName()}
}

// GetInputRawColumnNames returns the raw columns this column depends on;
// like the other raw column kinds, an inferred column depends only on itself.
func (rawColumn *RawInferredColumn) GetInputRawColumnNames() []string {
	return []string{rawColumn.GetName()}
}
20 changes: 14 additions & 6 deletions pkg/operator/api/context/serialize.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,10 @@ import (
)

type RawColumnsTypeSplit struct {
RawIntColumns map[string]*RawIntColumn `json:"raw_int_columns"`
RawStringColumns map[string]*RawStringColumn `json:"raw_string_columns"`
RawFloatColumns map[string]*RawFloatColumn `json:"raw_float_columns"`
RawIntColumns map[string]*RawIntColumn `json:"raw_int_columns"`
RawStringColumns map[string]*RawStringColumn `json:"raw_string_columns"`
RawFloatColumns map[string]*RawFloatColumn `json:"raw_float_columns"`
RawInferredColumns map[string]*RawInferredColumn `json:"raw_inferred_columns"`
}

type DataSplit struct {
Expand All @@ -45,6 +46,7 @@ func (ctx Context) splitRawColumns() *RawColumnsTypeSplit {
var rawIntColumns = make(map[string]*RawIntColumn)
var rawFloatColumns = make(map[string]*RawFloatColumn)
var rawStringColumns = make(map[string]*RawStringColumn)
var rawInferredColumns = make(map[string]*RawInferredColumn)
for name, rawColumn := range ctx.RawColumns {
switch typedRawColumn := rawColumn.(type) {
case *RawIntColumn:
Expand All @@ -53,13 +55,16 @@ func (ctx Context) splitRawColumns() *RawColumnsTypeSplit {
rawFloatColumns[name] = typedRawColumn
case *RawStringColumn:
rawStringColumns[name] = typedRawColumn
case *RawInferredColumn:
rawInferredColumns[name] = typedRawColumn
}
}

return &RawColumnsTypeSplit{
RawIntColumns: rawIntColumns,
RawFloatColumns: rawFloatColumns,
RawStringColumns: rawStringColumns,
RawIntColumns: rawIntColumns,
RawFloatColumns: rawFloatColumns,
RawStringColumns: rawStringColumns,
RawInferredColumns: rawInferredColumns,
}
}

Expand All @@ -75,6 +80,9 @@ func (serial Serial) collectRawColumns() RawColumns {
for name, rawColumn := range serial.RawColumnSplit.RawStringColumns {
rawColumns[name] = rawColumn
}
for name, rawColumn := range serial.RawColumnSplit.RawInferredColumns {
rawColumns[name] = rawColumn
}

return rawColumns
}
Expand Down
3 changes: 3 additions & 0 deletions pkg/operator/api/userconfig/column_type.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ const (
IntegerListColumnType
FloatListColumnType
StringListColumnType
InferredColumnType
)

var columnTypes = []string{
Expand All @@ -41,6 +42,7 @@ var columnTypes = []string{
"INT_LIST_COLUMN",
"FLOAT_LIST_COLUMN",
"STRING_LIST_COLUMN",
"INFERRED_COLUMN",
}

var columnJSONPlaceholders = []string{
Expand All @@ -51,6 +53,7 @@ var columnJSONPlaceholders = []string{
"[INT]",
"[FLOAT]",
"[\"STRING\"]",
"INFER",
}

func ColumnTypeFromString(s string) ColumnType {
Expand Down
24 changes: 24 additions & 0 deletions pkg/operator/api/userconfig/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ import (
"io/ioutil"
"strings"

k8sresource "k8s.io/apimachinery/pkg/api/resource"

"github.com/cortexlabs/cortex/pkg/lib/cast"
"github.com/cortexlabs/cortex/pkg/lib/configreader"
cr "github.com/cortexlabs/cortex/pkg/lib/configreader"
Expand Down Expand Up @@ -440,6 +442,28 @@ func New(configs map[string][]byte, envName string) (*Config, error) {
}
}

rawColumnNames := config.RawColumns.Names()
for _, env := range config.Environments {
ingestedColumnNames := env.Data.GetIngestedColumns()
missingColumns := slices.SubtractStrSlice(ingestedColumnNames, rawColumnNames)
for _, inferredColumn := range missingColumns {
inferredRawColumn := &RawInferredColumn{
ResourceFields: ResourceFields{
Name: inferredColumn,
},
Type: InferredColumnType,
Compute: &SparkCompute{
Executors: 1,
DriverCPU: Quantity{Quantity: k8sresource.MustParse("1")},
ExecutorCPU: Quantity{Quantity: k8sresource.MustParse("1")},
DriverMem: Quantity{Quantity: k8sresource.MustParse("500Mi")},
ExecutorMem: Quantity{Quantity: k8sresource.MustParse("500Mi")},
},
}
config.RawColumns = append(config.RawColumns, inferredRawColumn)
}
}

if err := config.Validate(envName); err != nil {
return nil, err
}
Expand Down
43 changes: 32 additions & 11 deletions pkg/operator/api/userconfig/raw_columns.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ type RawColumn interface {
Column
GetType() ColumnType
GetCompute() *SparkCompute
GetUserConfig() Resource
}

type RawColumns []RawColumn
Expand Down Expand Up @@ -181,6 +180,24 @@ var rawStringColumnFieldValidations = []*cr.StructFieldValidation{
typeFieldValidation,
}

// RawInferredColumn is the user-config resource for a raw column that was
// referenced in the ingested data but not explicitly declared, so its type
// is left to be inferred (Type is set to InferredColumnType when constructed).
type RawInferredColumn struct {
	ResourceFields
	// Type is the column type; for inferred columns this is InferredColumnType.
	Type ColumnType `json:"type" yaml:"type"`
	// Compute holds the Spark resources used to ingest/process this column.
	Compute *SparkCompute `json:"compute" yaml:"compute"`
}

// rawInferredColumnFieldValidations validates a RawInferredColumn's user
// config: "name" is required and restricted to alphanumerics, dashes, and
// underscores, and "compute" is validated by the shared Spark compute rules.
var rawInferredColumnFieldValidations = []*cr.StructFieldValidation{
	{
		Key:         "name",
		StructField: "Name",
		StringValidation: &cr.StringValidation{
			AlphaNumericDashUnderscore: true,
			Required:                   true,
		},
	},
	sparkComputeFieldValidation("Compute"),
}

func (rawColumns RawColumns) Validate() error {
resources := make([]Resource, len(rawColumns))
for i, res := range rawColumns {
Expand Down Expand Up @@ -224,6 +241,10 @@ func (column *RawStringColumn) GetType() ColumnType {
return column.Type
}

// GetType returns the column's declared type (InferredColumnType for
// inferred columns), satisfying the RawColumn interface.
func (column *RawInferredColumn) GetType() ColumnType {
	return column.Type
}

// GetCompute returns the Spark compute configuration for this column.
func (column *RawIntColumn) GetCompute() *SparkCompute {
	return column.Compute
}
Expand All @@ -236,6 +257,10 @@ func (column *RawStringColumn) GetCompute() *SparkCompute {
return column.Compute
}

// GetCompute returns the Spark compute configuration for this column.
func (column *RawInferredColumn) GetCompute() *SparkCompute {
	return column.Compute
}

// GetResourceType identifies this resource as a raw column.
func (column *RawIntColumn) GetResourceType() resource.Type {
	return resource.RawColumnType
}
Expand All @@ -248,6 +273,10 @@ func (column *RawStringColumn) GetResourceType() resource.Type {
return resource.RawColumnType
}

// GetResourceType identifies this resource as a raw column; inferred
// columns share the same resource type as the explicitly typed kinds.
func (column *RawInferredColumn) GetResourceType() resource.Type {
	return resource.RawColumnType
}

// IsRaw reports that this column is a raw (not transformed) column.
func (column *RawIntColumn) IsRaw() bool {
	return true
}
Expand All @@ -260,14 +289,6 @@ func (column *RawStringColumn) IsRaw() bool {
return true
}

func (column *RawIntColumn) GetUserConfig() Resource {
return column
}

func (column *RawFloatColumn) GetUserConfig() Resource {
return column
}

func (column *RawStringColumn) GetUserConfig() Resource {
return column
// IsRaw reports that this column is a raw (not transformed) column.
func (column *RawInferredColumn) IsRaw() bool {
	return true
}
5 changes: 5 additions & 0 deletions pkg/operator/api/userconfig/validators.go
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,11 @@ func CheckColumnRuntimeTypesMatch(columnRuntimeTypes map[string]interface{}, col
if !ok {
return errors.Wrap(ErrorUnsupportedColumnType(columnRuntimeTypeInter, validTypes), columnInputName)
}

if columnRuntimeType == InferredColumnType {
continue
}

if !slices.HasString(validTypes, columnRuntimeType.String()) {
return errors.Wrap(ErrorUnsupportedColumnType(columnRuntimeType, validTypes), columnInputName)
}
Expand Down
14 changes: 13 additions & 1 deletion pkg/operator/context/raw_columns.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,20 @@ func getRawColumns(
},
RawStringColumn: typedColumnConfig,
}
case *userconfig.RawInferredColumn:
buf.WriteString(typedColumnConfig.Name)
id := hash.Bytes(buf.Bytes())
rawColumn = &context.RawInferredColumn{
ComputedResourceFields: &context.ComputedResourceFields{
ResourceFields: &context.ResourceFields{
ID: id,
ResourceType: resource.RawColumnType,
},
},
RawInferredColumn: typedColumnConfig,
}
default:
return nil, errors.Wrap(configreader.ErrorInvalidStr(userconfig.TypeKey, userconfig.IntegerColumnType.String(), userconfig.FloatColumnType.String(), userconfig.StringColumnType.String()), userconfig.Identify(columnConfig)) // unexpected error
return nil, errors.Wrap(configreader.ErrorInvalidStr(typedColumnConfig.GetType().String(), userconfig.IntegerColumnType.String(), userconfig.FloatColumnType.String(), userconfig.StringColumnType.String()), userconfig.Identify(columnConfig)) // unexpected error
}

rawColumns[columnConfig.GetName()] = rawColumn
Expand Down
2 changes: 2 additions & 0 deletions pkg/workloads/consts.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
COLUMN_TYPE_INT_LIST = "INT_LIST_COLUMN"
COLUMN_TYPE_FLOAT_LIST = "FLOAT_LIST_COLUMN"
COLUMN_TYPE_STRING_LIST = "STRING_LIST_COLUMN"
COLUMN_TYPE_INFERRED = "INFERRED_COLUMN"

COLUMN_LIST_TYPES = [COLUMN_TYPE_INT_LIST, COLUMN_TYPE_FLOAT_LIST, COLUMN_TYPE_STRING_LIST]

Expand All @@ -30,6 +31,7 @@
COLUMN_TYPE_INT_LIST,
COLUMN_TYPE_FLOAT_LIST,
COLUMN_TYPE_STRING_LIST,
COLUMN_TYPE_INFERRED,
]

VALUE_TYPE_INT = "INT"
Expand Down
8 changes: 2 additions & 6 deletions pkg/workloads/lib/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -489,7 +489,7 @@ def get_metadata(self, resource_id, use_cache=True):
def get_inferred_column_type(self, column_name):
column = self.columns[column_name]
column_type = self.columns[column_name].get("type", "unknown")
if column_type == "unknown":
if column_type == "unknown" or column_type == "INFERRED_COLUMN":
column_type = self.get_metadata(column["id"])["type"]
self.columns[column_name]["type"] = column_type

Expand Down Expand Up @@ -575,11 +575,7 @@ def create_inputs_map(values_map, input_config):

def _deserialize_raw_ctx(raw_ctx):
raw_columns = raw_ctx["raw_columns"]
raw_ctx["raw_columns"] = util.merge_dicts_overwrite(
raw_columns["raw_int_columns"],
raw_columns["raw_float_columns"],
raw_columns["raw_string_columns"],
)
raw_ctx["raw_columns"] = util.merge_dicts_overwrite(*raw_columns.values())

data_split = raw_ctx["environment_data"]

Expand Down
18 changes: 0 additions & 18 deletions pkg/workloads/lib/tf_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,24 +30,6 @@
}


def add_tf_types(config):
if not util.is_dict(config):
return

type_fields = {}
for k, v in config.items():
if util.is_str(k) and util.is_str(v) and v in consts.COLUMN_TYPES:
type_fields[k] = v
elif util.is_dict(v):
add_tf_types(v)
elif util.is_list(v):
for sub_v in v:
add_tf_types(sub_v)

for k, v in type_fields.items():
config[k + "_tf"] = CORTEX_TYPE_TO_TF_TYPE[v]


def set_logging_verbosity(verbosity):
tf.logging.set_verbosity(verbosity)
os.environ["TF_CPP_MIN_LOG_LEVEL"] = str(tf.logging.__dict__[verbosity] / 10)
Expand Down
4 changes: 0 additions & 4 deletions pkg/workloads/lib/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,10 +316,6 @@ def flatten(var):
return [var]


def subtract_lists(lst1, lst2):
return [e for e in lst1 if e not in lst2]


def keep_dict_keys(d, keys):
key_set = set(keys)
for key in list(d.keys()):
Expand Down
Loading