
Commit a137b12

Clean up
1 parent da86b23 commit a137b12

File tree

5 files changed: +93 −83 lines changed


pkg/consts/consts.go

Lines changed: 0 additions & 1 deletion
@@ -46,7 +46,6 @@ var (
 	AggregatesDir     = "aggregates"
 	TransformersDir   = "transformers"
 	EstimatorsDir     = "estimators"
-	ModelImplsDir     = "model_implementations"
 	PythonPackagesDir = "python_packages"
 	ModelsDir         = "models"
 	ConstantsDir      = "constants"

pkg/workloads/lib/context.py

Lines changed: 41 additions & 22 deletions
@@ -415,14 +415,15 @@ def get_inferred_column_type(self, column_name):
 
         return column_type
 
-    # replaces column references with column names (unless preserve_column_refs = true, then leaves them untouched)
+    # Replace aggregates and constants with their values, and columns with their names (unless preserve_column_refs == False)
+    # Also validate against input_schema (if not None)
     def populate_values(self, input, input_schema, preserve_column_refs):
         if input is None:
             if input_schema is None:
                 return None
-            if input_schema["_allow_null"]:
+            if input_schema.get("_allow_null") == True:
                 return None
-            raise UserException("Null is not allowed")
+            raise UserException("Null value is not allowed")
 
         if util.is_resource_ref(input):
             res_name = util.get_resource_ref(input)
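
Note: besides the reworded message, switching from input_schema["_allow_null"] to input_schema.get("_allow_null") means a schema without that key now falls through to the UserException instead of raising a KeyError. A minimal sketch of the new behavior, assuming a schema dict shaped like the ones this method consumes (the ctx harness and the INT_COLUMN type are hypothetical):

    schema = {"_type": "INT_COLUMN", "_allow_null": True}
    ctx.populate_values(None, schema, preserve_column_refs=False)  # -> None

    strict = {"_type": "INT_COLUMN"}  # no _allow_null key
    # old code: KeyError; new code: UserException("Null value is not allowed")
    ctx.populate_values(None, strict, preserve_column_refs=False)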
@@ -447,8 +448,10 @@ def populate_values(self, input, input_schema, preserve_column_refs):
                 col_type = self.get_inferred_column_type(res_name)
                 if col_type not in input_schema["_type"]:
                     raise UserException(
-                        "column {}: column type mismatch: got {}, expected {}".format(
-                            res_name, col_type, input_schema["_type"]
+                        "column {}: unsupported input type (expected type {}, got type {})".format(
+                            res_name,
+                            util.data_type_str(input_schema["_type"]),
+                            util.data_type_str(col_type),
                         )
                     )
             if preserve_column_refs:
@@ -460,21 +463,25 @@ def populate_values(self, input, input_schema, preserve_column_refs):
             elem_schema = None
             if input_schema is not None:
                 if not util.is_list(input_schema["_type"]):
-                    raise UserException("unexpected type (list)")
+                    raise UserException(
+                        "unsupported input type (expected type {}, got {})".format(
+                            util.data_type_str(input_schema["_type"]), util.pp_str_flat(input)
+                        )
+                    )
                 elem_schema = input_schema["_type"][0]
 
                 min_count = input_schema.get("_min_count")
                 if min_count is not None and len(input) < min_count:
                     raise UserException(
-                        "list has length {}, but the minimum length is {}".format(
+                        "list has length {}, but the minimum allowed length is {}".format(
                             len(input), min_count
                         )
                     )
 
                 max_count = input_schema.get("_max_count")
                 if max_count is not None and len(input) > max_count:
                     raise UserException(
-                        "list has length {}, but the maximum length is {}".format(
+                        "list has length {}, but the maximum allowed length is {}".format(
                             len(input), max_count
                         )
                     )
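
A sketch of the list validation above, assuming the same schema shape and that @ marks a column reference (the names and counts are made up):

    schema = {"_type": ["FLOAT_COLUMN"], "_min_count": 2}
    # a single-element list now fails with:
    # "list has length 1, but the minimum allowed length is 2"
    ctx.populate_values(["@col1"], schema, preserve_column_refs=True)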
@@ -496,24 +503,32 @@ def populate_values(self, input, input_schema, preserve_column_refs):
                     try:
                         val_casted = self.populate_values(val, None, preserve_column_refs)
                     except CortexException as e:
-                        e.wrap(util.pp_str_flat(key_casted))
+                        e.wrap(util.pp_str_flat(key))
                         raise
                     casted[key_casted] = val_casted
                 return casted
 
             if not util.is_dict(input_schema["_type"]):
-                raise UserException("unexpected type (map)")
+                raise UserException(
+                    "unsupported input type (expected type {}, got {})".format(
+                        util.data_type_str(input_schema["_type"]), util.pp_str_flat(input)
+                    )
+                )
 
             min_count = input_schema.get("_min_count")
             if min_count is not None and len(input) < min_count:
                 raise UserException(
-                    "map has length {}, but the minimum length is {}".format(len(input), min_count)
+                    "map has length {}, but the minimum allowed length is {}".format(
+                        len(input), min_count
+                    )
                 )
 
             max_count = input_schema.get("_max_count")
             if max_count is not None and len(input) > max_count:
                 raise UserException(
-                    "map has length {}, but the maximum length is {}".format(len(input), max_count)
+                    "map has length {}, but the maximum allowed length is {}".format(
+                        len(input), max_count
+                    )
                 )
 
             is_generic_map = False
@@ -535,23 +550,23 @@ def populate_values(self, input, input_schema, preserve_column_refs):
                             val, generic_map_value, preserve_column_refs
                         )
                     except CortexException as e:
-                        e.wrap(util.pp_str_flat(key_casted))
+                        e.wrap(util.pp_str_flat(key))
                         raise
                     casted[key_casted] = val_casted
                 return casted
 
             # fixed map
             casted = {}
             for key, val_schema in input_schema["_type"].items():
-                default = None
-                if key not in input:
+                if key in input:
+                    val = input[key]
+                else:
                     if val_schema.get("_optional") is not True:
                         raise UserException("missing key: " + util.pp_str_flat(key))
                     if val_schema.get("_default") is None:
                         continue
-                    default = val_schema["_default"]
+                    val = val_schema["_default"]
 
-                val = input.get(key, default)
                 try:
                     val_casted = self.populate_values(val, val_schema, preserve_column_refs)
                 except CortexException as e:
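
The fixed-map rewrite assigns val directly on each branch instead of staging a default and re-reading it via input.get(key, default); the behavior is unchanged, but the two paths are now explicit. A sketch with a hypothetical schema:

    schema = {"_type": {"bucket_count": {"_type": "INT", "_optional": True, "_default": 10}}}
    ctx.populate_values({}, schema, preserve_column_refs=False)
    # -> {"bucket_count": 10}  (key absent: the default is used)
    ctx.populate_values({"bucket_count": 4}, schema, preserve_column_refs=False)
    # -> {"bucket_count": 4}   (key present: the value is used as-is)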
@@ -562,8 +577,12 @@ def populate_values(self, input, input_schema, preserve_column_refs):
 
         if input_schema is None:
            return input
-        if util.is_list(input_schema["_type"]) or util.is_dict(input_schema["_type"]):
-            raise UserException("unexpected type (scalar)")
+        if not util.is_str(input_schema["_type"]):
+            raise UserException(
+                "unsupported input type (expected type {}, got {})".format(
+                    util.data_type_str(input_schema["_type"]), util.pp_str_flat(input)
+                )
+            )
         return cast_compound_type(input, input_schema["_type"])
 
 
@@ -605,8 +624,8 @@ def cast_compound_type(value, type_str):
             return value
 
     raise UserException(
-        "input value's type is not supported by the schema (got {}, expected input with type {})".format(
-            util.pp_str_flat(value), type_str
+        "unsupported input type (expected type {}, got {})".format(
+            util.data_type_str(type_str), util.pp_str_flat(value)
         )
     )
 
@@ -689,7 +708,7 @@ def _deserialize_raw_ctx(raw_ctx):
 def create_transformer_inputs_from_map(input, col_value_map):
     if util.is_str(input):
         res_name = util.get_resource_ref(input)
-        if res_name in col_value_map:
+        if res_name is not None and res_name in col_value_map:
             return col_value_map[res_name]
     return input
 
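The added res_name is not None guard matters when input is a plain string rather than a resource reference: get_resource_ref then returns a null result, and the old membership test could misbehave if col_value_map ever contained such a key. A sketch, assuming @ marks a column reference:

    col_value_map = {"age": 35}
    create_transformer_inputs_from_map("@age", col_value_map)  # -> 35
    create_transformer_inputs_from_map("mean", col_value_map)  # -> "mean" (not a ref; passed through)
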
pkg/workloads/lib/util.py

Lines changed: 13 additions & 12 deletions
@@ -66,6 +66,11 @@ def pp_str_flat(obj, indent=0):
     return indent_str(out, indent)
 
 
+def data_type_str(obj):
+    # TODO. Also call this method with output types?
+    return pp_str_flat(obj)
+
+
 def log_indent(obj, indent=0, logging_func=logger.info):
     if not is_str(obj):
         text = repr(obj)
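
data_type_str is a thin alias for pp_str_flat for now (the TODO notes it may also be applied to output types later); the context.py callers use it to render the expected schema type in error messages. A usage sketch:

    # renders a schema type flat, e.g. for the expected-type slot in
    # "unsupported input type (expected type {}, got {})"
    msg = "expected type {}".format(data_type_str(["FLOAT_COLUMN"]))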
@@ -748,8 +753,6 @@ def validate_output_type(value, output_type):
         return False
 
     if is_list(output_type):
-        if not (len(output_type) == 1 and is_str(output_type[0])):
-            return False
         if not is_list(value):
             return False
         for value_item in value:
@@ -760,8 +763,6 @@ def validate_output_type(value, output_type):
     if is_dict(output_type):
         if not is_dict(value):
             return False
-        if len(output_type) == 0:
-            return False
 
         is_generic_map = False
         if len(output_type) == 1:
@@ -787,10 +788,10 @@ def validate_output_type(value, output_type):
                 return False
         return True
 
-    return False
+    return False  # unexpected
 
 
-# Casts int -> float. Input is assumed to be already validated
+# value is assumed to be already validated against output_type
 def cast_output_type(value, output_type):
     if is_str(output_type):
         if (
@@ -858,17 +859,17 @@ def extract_resource_refs(input):
             return {res}
         return set()
 
+    if is_list(input):
+        resources = set()
+        for item in input:
+            resources = resources.union(extract_resource_refs(item))
+        return resources
+
     if is_dict(input):
         resources = set()
         for key, val in input.items():
             resources = resources.union(extract_resource_refs(key))
             resources = resources.union(extract_resource_refs(val))
         return resources
 
-    if is_list(input):
-        resources = set()
-        for item in input:
-            resources = resources.union(extract_resource_refs(item))
-        return resources
-
     return set()
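
Moving the list branch above the dict branch is purely cosmetic: the two branches are mutually exclusive, so the result is unchanged. A sketch of the recursion over nested input, assuming @ marks a resource reference:

    input = ["@age", {"weight": "@height"}]
    extract_resource_refs(input)  # -> {"age", "height"}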

pkg/workloads/spark_job/spark_util.py

Lines changed: 32 additions & 38 deletions
@@ -397,12 +397,12 @@ def read_parquet(ctx, spark):
 
 
 def split_aggregators(aggregate_names, ctx):
-    aggregate_resources = [ctx.aggregates[agg_name] for agg_name in aggregate_names]
+    aggregates = [ctx.aggregates[agg_name] for agg_name in aggregate_names]
 
     builtin_aggregates = []
     custom_aggregates = []
 
-    for agg in aggregate_resources:
+    for agg in aggregates:
         aggregator = ctx.aggregators[agg["aggregator"]]
         if aggregator.get("namespace", None) == "cortex" and aggregator["name"] in AGG_SPARK_LIST:
             builtin_aggregates.append(agg)
@@ -416,52 +416,50 @@ def run_builtin_aggregators(builtin_aggregates, df, ctx, spark):
     agg_cols = []
     for agg in builtin_aggregates:
         aggregator = ctx.aggregators[agg["aggregator"]]
-        input_repl = ctx.populate_values(
-            agg["input"], aggregator["input"], preserve_column_refs=False
-        )
+        input = ctx.populate_values(agg["input"], aggregator["input"], preserve_column_refs=False)
 
         if aggregator["name"] == "approx_count_distinct":
             agg_cols.append(
-                F.approxCountDistinct(input_repl["col"], input_repl.get("rsd")).alias(agg["name"])
+                F.approxCountDistinct(input["col"], input.get("rsd")).alias(agg["name"])
             )
         if aggregator["name"] == "avg":
-            agg_cols.append(F.avg(input_repl).alias(agg["name"]))
+            agg_cols.append(F.avg(input).alias(agg["name"]))
         if aggregator["name"] in {"collect_set_int", "collect_set_float", "collect_set_string"}:
-            agg_cols.append(F.collect_set(input_repl).alias(agg["name"]))
+            agg_cols.append(F.collect_set(input).alias(agg["name"]))
         if aggregator["name"] == "count":
-            agg_cols.append(F.count(input_repl).alias(agg["name"]))
+            agg_cols.append(F.count(input).alias(agg["name"]))
         if aggregator["name"] == "count_distinct":
-            agg_cols.append(F.countDistinct(*input_repl).alias(agg["name"]))
+            agg_cols.append(F.countDistinct(*input).alias(agg["name"]))
         if aggregator["name"] == "covar_pop":
-            agg_cols.append(F.covar_pop(input_repl["col1"], input_repl["col2"]).alias(agg["name"]))
+            agg_cols.append(F.covar_pop(input["col1"], input["col2"]).alias(agg["name"]))
         if aggregator["name"] == "covar_samp":
-            agg_cols.append(F.covar_samp(input_repl["col1"], input_repl["col2"]).alias(agg["name"]))
+            agg_cols.append(F.covar_samp(input["col1"], input["col2"]).alias(agg["name"]))
         if aggregator["name"] == "kurtosis":
-            agg_cols.append(F.kurtosis(input_repl).alias(agg["name"]))
+            agg_cols.append(F.kurtosis(input).alias(agg["name"]))
         if aggregator["name"] in {"max_int", "max_float", "max_string"}:
-            agg_cols.append(F.max(input_repl).alias(agg["name"]))
+            agg_cols.append(F.max(input).alias(agg["name"]))
         if aggregator["name"] == "mean":
-            agg_cols.append(F.mean(input_repl).alias(agg["name"]))
+            agg_cols.append(F.mean(input).alias(agg["name"]))
         if aggregator["name"] in {"min_int", "min_float", "min_string"}:
-            agg_cols.append(F.min(input_repl).alias(agg["name"]))
+            agg_cols.append(F.min(input).alias(agg["name"]))
         if aggregator["name"] == "skewness":
-            agg_cols.append(F.skewness(input_repl).alias(agg["name"]))
+            agg_cols.append(F.skewness(input).alias(agg["name"]))
         if aggregator["name"] == "stddev":
-            agg_cols.append(F.stddev(input_repl).alias(agg["name"]))
+            agg_cols.append(F.stddev(input).alias(agg["name"]))
         if aggregator["name"] == "stddev_pop":
-            agg_cols.append(F.stddev_pop(input_repl).alias(agg["name"]))
+            agg_cols.append(F.stddev_pop(input).alias(agg["name"]))
         if aggregator["name"] == "stddev_samp":
-            agg_cols.append(F.stddev_samp(input_repl).alias(agg["name"]))
+            agg_cols.append(F.stddev_samp(input).alias(agg["name"]))
         if aggregator["name"] in {"sum_int", "sum_float"}:
-            agg_cols.append(F.sum(input_repl).alias(agg["name"]))
+            agg_cols.append(F.sum(input).alias(agg["name"]))
         if aggregator["name"] in {"sum_distinct_int", "sum_distinct_float"}:
-            agg_cols.append(F.sumDistinct(input_repl).alias(agg["name"]))
+            agg_cols.append(F.sumDistinct(input).alias(agg["name"]))
         if aggregator["name"] == "var_pop":
-            agg_cols.append(F.var_pop(input_repl).alias(agg["name"]))
+            agg_cols.append(F.var_pop(input).alias(agg["name"]))
         if aggregator["name"] == "var_samp":
-            agg_cols.append(F.var_samp(input_repl).alias(agg["name"]))
+            agg_cols.append(F.var_samp(input).alias(agg["name"]))
         if aggregator["name"] == "variance":
-            agg_cols.append(F.variance(input_repl).alias(agg["name"]))
+            agg_cols.append(F.variance(input).alias(agg["name"]))
 
     results = df.agg(*agg_cols).collect()[0].asDict()
@@ -479,12 +477,10 @@ def run_builtin_aggregators(builtin_aggregates, df, ctx, spark):
 def run_custom_aggregator(aggregate, df, ctx, spark):
     aggregator = ctx.aggregators[aggregate["aggregator"]]
     aggregator_impl, _ = ctx.get_aggregator_impl(aggregate["name"])
-    input_repl = ctx.populate_values(
-        aggregate["input"], aggregator["input"], preserve_column_refs=False
-    )
+    input = ctx.populate_values(aggregate["input"], aggregator["input"], preserve_column_refs=False)
 
     try:
-        result = aggregator_impl.aggregate_spark(df, input_repl)
+        result = aggregator_impl.aggregate_spark(df, input)
     except Exception as e:
         raise UserRuntimeException(
             "aggregate " + aggregate["name"],
@@ -517,11 +513,11 @@ def execute_transform_spark(column_name, df, ctx, spark):
         spark.sparkContext.addPyFile(trans_impl_path)  # Executor pods need this because of the UDF
         ctx.spark_uploaded_impls[trans_impl_path] = True
 
-    input_repl = ctx.populate_values(
+    input = ctx.populate_values(
         transformed_column["input"], transformer["input"], preserve_column_refs=False
     )
     try:
-        return trans_impl.transform_spark(df, input_repl, column_name)
+        return trans_impl.transform_spark(df, input, column_name)
     except Exception as e:
         raise UserRuntimeException("function transform_spark") from e
 
@@ -532,7 +528,7 @@ def execute_transform_python(column_name, df, ctx, spark, validate=False):
     transformer = ctx.transformers[transformed_column["transformer"]]
 
     input_cols_sorted = sorted(ctx.extract_column_names(transformed_column["input"]))
-    input_repl = ctx.populate_values(
+    input = ctx.populate_values(
         transformed_column["input"], transformer["input"], preserve_column_refs=True
     )
 
@@ -541,9 +537,7 @@ def execute_transform_python(column_name, df, ctx, spark, validate=False):
         ctx.spark_uploaded_impls[trans_impl_path] = True
 
     def _transform(*values):
-        transformer_input = create_transformer_inputs_from_lists(
-            input_repl, input_cols_sorted, values
-        )
+        transformer_input = create_transformer_inputs_from_lists(input, input_cols_sorted, values)
         return trans_impl.transform_python(transformer_input)
 
     transform_python_func = _transform
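
Here preserve_column_refs=True, so input still contains column references; create_transformer_inputs_from_lists substitutes each row's values positionally using input_cols_sorted. A sketch of the presumed mapping (hypothetical data, @ assumed to mark a column reference):

    # input             = {"col": "@age", "scale": 10}
    # input_cols_sorted = ["age"]
    # values            = (35,)  (the row's value for "age")
    # transformer_input == {"col": 35, "scale": 10}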
@@ -593,15 +587,15 @@ def validate_transformer(column_name, test_df, ctx, spark):
     if transformer["output_type"] == consts.COLUMN_TYPE_INFERRED:
         sample_df = test_df.collect()
         sample = sample_df[0]
-        input_repl = ctx.populate_values(
+        input = ctx.populate_values(
             transformed_column["input"], transformer["input"], preserve_column_refs=True
         )
-        transformer_input = create_transformer_inputs_from_map(input_repl, sample)
+        transformer_input = create_transformer_inputs_from_map(input, sample)
         initial_transformed_sample = trans_impl.transform_python(transformer_input)
         inferred_python_type = infer_type(initial_transformed_sample)
 
         for row in sample_df:
-            transformer_input = create_transformer_inputs_from_map(input_repl, row)
+            transformer_input = create_transformer_inputs_from_map(input, row)
             transformed_sample = trans_impl.transform_python(transformer_input)
             if inferred_python_type != infer_type(transformed_sample):
                 raise UserRuntimeException(
