@@ -397,12 +397,12 @@ def read_parquet(ctx, spark):
 
 
 def split_aggregators(aggregate_names, ctx):
-    aggregate_resources = [ctx.aggregates[agg_name] for agg_name in aggregate_names]
+    aggregates = [ctx.aggregates[agg_name] for agg_name in aggregate_names]
 
     builtin_aggregates = []
     custom_aggregates = []
 
-    for agg in aggregate_resources:
+    for agg in aggregates:
         aggregator = ctx.aggregators[agg["aggregator"]]
         if aggregator.get("namespace", None) == "cortex" and aggregator["name"] in AGG_SPARK_LIST:
             builtin_aggregates.append(agg)
@@ -416,52 +416,50 @@ def run_builtin_aggregators(builtin_aggregates, df, ctx, spark):
     agg_cols = []
     for agg in builtin_aggregates:
         aggregator = ctx.aggregators[agg["aggregator"]]
-        input_repl = ctx.populate_values(
-            agg["input"], aggregator["input"], preserve_column_refs=False
-        )
+        input = ctx.populate_values(agg["input"], aggregator["input"], preserve_column_refs=False)
 
         if aggregator["name"] == "approx_count_distinct":
             agg_cols.append(
-                F.approxCountDistinct(input_repl["col"], input_repl.get("rsd")).alias(agg["name"])
+                F.approxCountDistinct(input["col"], input.get("rsd")).alias(agg["name"])
             )
         if aggregator["name"] == "avg":
-            agg_cols.append(F.avg(input_repl).alias(agg["name"]))
+            agg_cols.append(F.avg(input).alias(agg["name"]))
         if aggregator["name"] in {"collect_set_int", "collect_set_float", "collect_set_string"}:
-            agg_cols.append(F.collect_set(input_repl).alias(agg["name"]))
+            agg_cols.append(F.collect_set(input).alias(agg["name"]))
         if aggregator["name"] == "count":
-            agg_cols.append(F.count(input_repl).alias(agg["name"]))
+            agg_cols.append(F.count(input).alias(agg["name"]))
         if aggregator["name"] == "count_distinct":
-            agg_cols.append(F.countDistinct(*input_repl).alias(agg["name"]))
+            agg_cols.append(F.countDistinct(*input).alias(agg["name"]))
         if aggregator["name"] == "covar_pop":
-            agg_cols.append(F.covar_pop(input_repl["col1"], input_repl["col2"]).alias(agg["name"]))
+            agg_cols.append(F.covar_pop(input["col1"], input["col2"]).alias(agg["name"]))
         if aggregator["name"] == "covar_samp":
-            agg_cols.append(F.covar_samp(input_repl["col1"], input_repl["col2"]).alias(agg["name"]))
+            agg_cols.append(F.covar_samp(input["col1"], input["col2"]).alias(agg["name"]))
         if aggregator["name"] == "kurtosis":
-            agg_cols.append(F.kurtosis(input_repl).alias(agg["name"]))
+            agg_cols.append(F.kurtosis(input).alias(agg["name"]))
         if aggregator["name"] in {"max_int", "max_float", "max_string"}:
-            agg_cols.append(F.max(input_repl).alias(agg["name"]))
+            agg_cols.append(F.max(input).alias(agg["name"]))
         if aggregator["name"] == "mean":
-            agg_cols.append(F.mean(input_repl).alias(agg["name"]))
+            agg_cols.append(F.mean(input).alias(agg["name"]))
         if aggregator["name"] in {"min_int", "min_float", "min_string"}:
-            agg_cols.append(F.min(input_repl).alias(agg["name"]))
+            agg_cols.append(F.min(input).alias(agg["name"]))
         if aggregator["name"] == "skewness":
-            agg_cols.append(F.skewness(input_repl).alias(agg["name"]))
+            agg_cols.append(F.skewness(input).alias(agg["name"]))
         if aggregator["name"] == "stddev":
-            agg_cols.append(F.stddev(input_repl).alias(agg["name"]))
+            agg_cols.append(F.stddev(input).alias(agg["name"]))
         if aggregator["name"] == "stddev_pop":
-            agg_cols.append(F.stddev_pop(input_repl).alias(agg["name"]))
+            agg_cols.append(F.stddev_pop(input).alias(agg["name"]))
         if aggregator["name"] == "stddev_samp":
-            agg_cols.append(F.stddev_samp(input_repl).alias(agg["name"]))
+            agg_cols.append(F.stddev_samp(input).alias(agg["name"]))
         if aggregator["name"] in {"sum_int", "sum_float"}:
-            agg_cols.append(F.sum(input_repl).alias(agg["name"]))
+            agg_cols.append(F.sum(input).alias(agg["name"]))
         if aggregator["name"] in {"sum_distinct_int", "sum_distinct_float"}:
-            agg_cols.append(F.sumDistinct(input_repl).alias(agg["name"]))
+            agg_cols.append(F.sumDistinct(input).alias(agg["name"]))
         if aggregator["name"] == "var_pop":
-            agg_cols.append(F.var_pop(input_repl).alias(agg["name"]))
+            agg_cols.append(F.var_pop(input).alias(agg["name"]))
         if aggregator["name"] == "var_samp":
-            agg_cols.append(F.var_samp(input_repl).alias(agg["name"]))
+            agg_cols.append(F.var_samp(input).alias(agg["name"]))
         if aggregator["name"] == "variance":
-            agg_cols.append(F.variance(input_repl).alias(agg["name"]))
+            agg_cols.append(F.variance(input).alias(agg["name"]))
 
     results = df.agg(*agg_cols).collect()[0].asDict()
 
@@ -479,12 +477,10 @@ def run_builtin_aggregators(builtin_aggregates, df, ctx, spark):
 def run_custom_aggregator(aggregate, df, ctx, spark):
     aggregator = ctx.aggregators[aggregate["aggregator"]]
     aggregator_impl, _ = ctx.get_aggregator_impl(aggregate["name"])
-    input_repl = ctx.populate_values(
-        aggregate["input"], aggregator["input"], preserve_column_refs=False
-    )
+    input = ctx.populate_values(aggregate["input"], aggregator["input"], preserve_column_refs=False)
 
     try:
-        result = aggregator_impl.aggregate_spark(df, input_repl)
+        result = aggregator_impl.aggregate_spark(df, input)
     except Exception as e:
         raise UserRuntimeException(
             "aggregate " + aggregate["name"],
@@ -517,11 +513,11 @@ def execute_transform_spark(column_name, df, ctx, spark):
         spark.sparkContext.addPyFile(trans_impl_path)  # Executor pods need this because of the UDF
         ctx.spark_uploaded_impls[trans_impl_path] = True
 
-    input_repl = ctx.populate_values(
+    input = ctx.populate_values(
         transformed_column["input"], transformer["input"], preserve_column_refs=False
     )
     try:
-        return trans_impl.transform_spark(df, input_repl, column_name)
+        return trans_impl.transform_spark(df, input, column_name)
     except Exception as e:
         raise UserRuntimeException("function transform_spark") from e
 
@@ -532,7 +528,7 @@ def execute_transform_python(column_name, df, ctx, spark, validate=False):
     transformer = ctx.transformers[transformed_column["transformer"]]
 
     input_cols_sorted = sorted(ctx.extract_column_names(transformed_column["input"]))
-    input_repl = ctx.populate_values(
+    input = ctx.populate_values(
         transformed_column["input"], transformer["input"], preserve_column_refs=True
     )
 
@@ -541,9 +537,7 @@ def execute_transform_python(column_name, df, ctx, spark, validate=False):
         ctx.spark_uploaded_impls[trans_impl_path] = True
 
     def _transform(*values):
-        transformer_input = create_transformer_inputs_from_lists(
-            input_repl, input_cols_sorted, values
-        )
+        transformer_input = create_transformer_inputs_from_lists(input, input_cols_sorted, values)
         return trans_impl.transform_python(transformer_input)
 
     transform_python_func = _transform
@@ -593,15 +587,15 @@ def validate_transformer(column_name, test_df, ctx, spark):
     if transformer["output_type"] == consts.COLUMN_TYPE_INFERRED:
         sample_df = test_df.collect()
         sample = sample_df[0]
-        input_repl = ctx.populate_values(
+        input = ctx.populate_values(
             transformed_column["input"], transformer["input"], preserve_column_refs=True
         )
-        transformer_input = create_transformer_inputs_from_map(input_repl, sample)
+        transformer_input = create_transformer_inputs_from_map(input, sample)
         initial_transformed_sample = trans_impl.transform_python(transformer_input)
         inferred_python_type = infer_type(initial_transformed_sample)
 
         for row in sample_df:
-            transformer_input = create_transformer_inputs_from_map(input_repl, row)
+            transformer_input = create_transformer_inputs_from_map(input, row)
             transformed_sample = trans_impl.transform_python(transformer_input)
             if inferred_python_type != infer_type(transformed_sample):
                 raise UserRuntimeException(
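For context on the builtin path touched above: `run_builtin_aggregators` only collects aliased `pyspark.sql.functions` column expressions and evaluates them in a single `df.agg(...)` pass. A minimal, self-contained sketch of that pattern (hypothetical column and aggregate names, not part of this change):

```python
# Illustrative sketch only -- mirrors the agg_cols / df.agg(*agg_cols) pattern above.
from pyspark.sql import SparkSession
import pyspark.sql.functions as F

spark = SparkSession.builder.master("local[1]").appName("builtin-agg-sketch").getOrCreate()
df = spark.createDataFrame([(1.0,), (2.0,), (3.0,)], ["price"])  # hypothetical data

# Each builtin aggregate becomes one aliased column expression.
agg_cols = [
    F.avg("price").alias("price_avg"),
    F.stddev("price").alias("price_stddev"),
    F.countDistinct("price").alias("price_count_distinct"),
]

# One aggregation pass; results come back as a dict keyed by aggregate name,
# matching results = df.agg(*agg_cols).collect()[0].asDict() in the diff.
results = df.agg(*agg_cols).collect()[0].asDict()
print(results)  # {'price_avg': 2.0, 'price_stddev': 1.0, 'price_count_distinct': 3}

spark.stop()
```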