Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion native-engine/auron-serde/proto/auron.proto
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,7 @@ enum ScalarFunction {
Nvl2=83;
Least=84;
Greatest=85;
SparkExtFunctions=10000;
AuronExtFunctions=10000;
}

message PhysicalScalarFunctionNode {
Expand Down
6 changes: 3 additions & 3 deletions native-engine/auron-serde/src/from_proto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -839,7 +839,7 @@ impl From<protobuf::ScalarFunction> for Arc<ScalarUDF> {
ScalarFunction::Power => f::math::power(),
ScalarFunction::IsNaN => f::math::isnan(),

ScalarFunction::SparkExtFunctions => {
ScalarFunction::AuronExtFunctions => {
unreachable!()
}
}
Expand Down Expand Up @@ -945,9 +945,9 @@ fn try_parse_physical_expr(
.map(|x| try_parse_physical_expr(x, input_schema))
.collect::<Result<Vec<_>, _>>()?;

let scalar_udf = if scalar_function == protobuf::ScalarFunction::SparkExtFunctions {
let scalar_udf = if scalar_function == protobuf::ScalarFunction::AuronExtFunctions {
let fun_name = &e.name;
let fun = datafusion_ext_functions::create_spark_ext_function(fun_name)?;
let fun = datafusion_ext_functions::create_auron_ext_function(fun_name)?;
Arc::new(create_udf(
&format!("spark_ext_function_{}", fun_name),
args.iter()
Expand Down
73 changes: 39 additions & 34 deletions native-engine/datafusion-ext-functions/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,43 +32,48 @@ mod spark_sha2;
mod spark_strings;
mod spark_unscaled_value;

pub fn create_spark_ext_function(name: &str) -> Result<ScalarFunctionImplementation> {
pub fn create_auron_ext_function(name: &str) -> Result<ScalarFunctionImplementation> {
    // Auron ext functions: names used for Spark should start with the 'Spark_'
    // prefix, names used for Flink with 'Flink_', and similarly for any other
    // engine.
Ok(match name {
"Placeholder" => Arc::new(|_| panic!("placeholder() should never be called")),
"NullIf" => Arc::new(spark_null_if::spark_null_if),
"NullIfZero" => Arc::new(spark_null_if::spark_null_if_zero),
"UnscaledValue" => Arc::new(spark_unscaled_value::spark_unscaled_value),
"MakeDecimal" => Arc::new(spark_make_decimal::spark_make_decimal),
"CheckOverflow" => Arc::new(spark_check_overflow::spark_check_overflow),
"Murmur3Hash" => Arc::new(spark_hash::spark_murmur3_hash),
"XxHash64" => Arc::new(spark_hash::spark_xxhash64),
"Sha224" => Arc::new(spark_sha2::spark_sha224),
"Sha256" => Arc::new(spark_sha2::spark_sha256),
"Sha384" => Arc::new(spark_sha2::spark_sha384),
"Sha512" => Arc::new(spark_sha2::spark_sha512),
"GetJsonObject" => Arc::new(spark_get_json_object::spark_get_json_object),
"GetParsedJsonObject" => Arc::new(spark_get_json_object::spark_get_parsed_json_object),
"ParseJson" => Arc::new(spark_get_json_object::spark_parse_json),
"MakeArray" => Arc::new(spark_make_array::array),
"StringSpace" => Arc::new(spark_strings::string_space),
"StringRepeat" => Arc::new(spark_strings::string_repeat),
"StringSplit" => Arc::new(spark_strings::string_split),
"StringConcat" => Arc::new(spark_strings::string_concat),
"StringConcatWs" => Arc::new(spark_strings::string_concat_ws),
"StringLower" => Arc::new(spark_strings::string_lower),
"StringUpper" => Arc::new(spark_strings::string_upper),
"Year" => Arc::new(spark_dates::spark_year),
"Month" => Arc::new(spark_dates::spark_month),
"Day" => Arc::new(spark_dates::spark_day),
"Quarter" => Arc::new(spark_dates::spark_quarter),
"Hour" => Arc::new(spark_dates::spark_hour),
"Minute" => Arc::new(spark_dates::spark_minute),
"Second" => Arc::new(spark_dates::spark_second),
"BrickhouseArrayUnion" => Arc::new(brickhouse::array_union::array_union),
"Round" => Arc::new(spark_round::spark_round),
"NormalizeNanAndZero" => {
"Spark_NullIf" => Arc::new(spark_null_if::spark_null_if),
"Spark_NullIfZero" => Arc::new(spark_null_if::spark_null_if_zero),
"Spark_UnscaledValue" => Arc::new(spark_unscaled_value::spark_unscaled_value),
"Spark_MakeDecimal" => Arc::new(spark_make_decimal::spark_make_decimal),
"Spark_CheckOverflow" => Arc::new(spark_check_overflow::spark_check_overflow),
"Spark_Murmur3Hash" => Arc::new(spark_hash::spark_murmur3_hash),
"Spark_XxHash64" => Arc::new(spark_hash::spark_xxhash64),
"Spark_Sha224" => Arc::new(spark_sha2::spark_sha224),
"Spark_Sha256" => Arc::new(spark_sha2::spark_sha256),
"Spark_Sha384" => Arc::new(spark_sha2::spark_sha384),
"Spark_Sha512" => Arc::new(spark_sha2::spark_sha512),
"Spark_GetJsonObject" => Arc::new(spark_get_json_object::spark_get_json_object),
"Spark_GetParsedJsonObject" => {
Arc::new(spark_get_json_object::spark_get_parsed_json_object)
}
"Spark_ParseJson" => Arc::new(spark_get_json_object::spark_parse_json),
"Spark_MakeArray" => Arc::new(spark_make_array::array),
"Spark_StringSpace" => Arc::new(spark_strings::string_space),
"Spark_StringRepeat" => Arc::new(spark_strings::string_repeat),
"Spark_StringSplit" => Arc::new(spark_strings::string_split),
"Spark_StringConcat" => Arc::new(spark_strings::string_concat),
"Spark_StringConcatWs" => Arc::new(spark_strings::string_concat_ws),
"Spark_StringLower" => Arc::new(spark_strings::string_lower),
"Spark_StringUpper" => Arc::new(spark_strings::string_upper),
"Spark_Year" => Arc::new(spark_dates::spark_year),
"Spark_Month" => Arc::new(spark_dates::spark_month),
"Spark_Day" => Arc::new(spark_dates::spark_day),
"Spark_Quarter" => Arc::new(spark_dates::spark_quarter),
"Spark_Hour" => Arc::new(spark_dates::spark_hour),
"Spark_Minute" => Arc::new(spark_dates::spark_minute),
"Spark_Second" => Arc::new(spark_dates::spark_second),
"Spark_BrickhouseArrayUnion" => Arc::new(brickhouse::array_union::array_union),
"Spark_Round" => Arc::new(spark_round::spark_round),
"Spark_NormalizeNanAndZero" => {
Arc::new(spark_normalize_nan_and_zero::spark_normalize_nan_and_zero)
}
_ => df_unimplemented_err!("spark ext function not implemented: {name}")?,
_ => df_unimplemented_err!("auron ext function not implemented: {name}")?,
})
}
Original file line number Diff line number Diff line change
Expand Up @@ -528,8 +528,8 @@ class ShimsImpl extends Shims with Logging {
.setScalarFunction(
pb.PhysicalScalarFunctionNode
.newBuilder()
.setFun(pb.ScalarFunction.SparkExtFunctions)
.setName("StringSplit")
.setFun(pb.ScalarFunction.AuronExtFunctions)
.setName("Spark_StringSplit")
.addArgs(NativeConverters.convertExprWithFallback(str, isPruningExpr, fallback))
.addArgs(NativeConverters
.convertExprWithFallback(Literal(nativePat), isPruningExpr, fallback))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -698,7 +698,7 @@ object NativeConverters extends Logging {
pb.PhysicalBinaryExprNode
.newBuilder()
.setL(convertExprWithFallback(Cast(lhs, resultType), isPruningExpr, fallback))
.setR(buildExtScalarFunction("NullIfZero", rhs :: Nil, rhs.dataType))
.setR(buildExtScalarFunction("Spark_NullIfZero", rhs :: Nil, rhs.dataType))
.setOp("Divide"))
}))
}
Expand All @@ -711,7 +711,7 @@ object NativeConverters extends Logging {
pb.PhysicalBinaryExprNode
.newBuilder()
.setL(convertExprWithFallback(lhsCasted, isPruningExpr, fallback))
.setR(buildExtScalarFunction("NullIfZero", rhsCasted :: Nil, rhs.dataType))
.setR(buildExtScalarFunction("Spark_NullIfZero", rhsCasted :: Nil, rhs.dataType))
.setOp("Divide"))
}
}
Expand All @@ -733,7 +733,8 @@ object NativeConverters extends Logging {
pb.PhysicalBinaryExprNode
.newBuilder()
.setL(convertExprWithFallback(lhsCasted, isPruningExpr, fallback))
.setR(buildExtScalarFunction("NullIfZero", rhsCasted :: Nil, rhs.dataType))
.setR(
buildExtScalarFunction("Spark_NullIfZero", rhsCasted :: Nil, rhs.dataType))
.setOp("Modulo"))
}
}
Expand Down Expand Up @@ -832,9 +833,9 @@ object NativeConverters extends Logging {
case e: Round =>
e.scale match {
case Literal(n: Int, _) =>
buildExtScalarFunction("Round", Seq(e.child, Literal(n.toLong)), e.dataType)
buildExtScalarFunction("Spark_Round", Seq(e.child, Literal(n.toLong)), e.dataType)
case _ =>
buildExtScalarFunction("Round", Seq(e.child, Literal(0L)), e.dataType)
buildExtScalarFunction("Spark_Round", Seq(e.child, Literal(0L)), e.dataType)
}

case e: Signum => buildScalarFunction(pb.ScalarFunction.Signum, e.children, e.dataType)
Expand All @@ -849,10 +850,10 @@ object NativeConverters extends Logging {

case e: Lower
if sparkAuronConfig.getBoolean(SparkAuronConfiguration.CASE_CONVERT_FUNCTIONS_ENABLE) =>
buildExtScalarFunction("StringLower", e.children, e.dataType)
buildExtScalarFunction("Spark_StringLower", e.children, e.dataType)
case e: Upper
if sparkAuronConfig.getBoolean(SparkAuronConfiguration.CASE_CONVERT_FUNCTIONS_ENABLE) =>
buildExtScalarFunction("StringUpper", e.children, e.dataType)
buildExtScalarFunction("Spark_StringUpper", e.children, e.dataType)

case e: StringTrim =>
buildScalarFunction(pb.ScalarFunction.Trim, e.srcStr +: e.trimStr.toSeq, e.dataType)
Expand All @@ -861,48 +862,48 @@ object NativeConverters extends Logging {
case e: StringTrimRight =>
buildScalarFunction(pb.ScalarFunction.Rtrim, e.srcStr +: e.trimStr.toSeq, e.dataType)
case e @ NullIf(left, right, _) =>
buildExtScalarFunction("NullIf", left :: right :: Nil, e.dataType)
buildExtScalarFunction("Spark_NullIf", left :: right :: Nil, e.dataType)
case Md5(_1) =>
buildScalarFunction(pb.ScalarFunction.MD5, Seq(unpackBinaryTypeCast(_1)), StringType)
case Reverse(_1) =>
buildScalarFunction(pb.ScalarFunction.Reverse, Seq(unpackBinaryTypeCast(_1)), StringType)
case InitCap(_1) =>
buildScalarFunction(pb.ScalarFunction.InitCap, Seq(unpackBinaryTypeCast(_1)), StringType)
case Sha2(_1, Literal(224, _)) =>
buildExtScalarFunction("Sha224", Seq(unpackBinaryTypeCast(_1)), StringType)
buildExtScalarFunction("Spark_Sha224", Seq(unpackBinaryTypeCast(_1)), StringType)
case Sha2(_1, Literal(0, _)) =>
buildExtScalarFunction("Sha256", Seq(unpackBinaryTypeCast(_1)), StringType)
buildExtScalarFunction("Spark_Sha256", Seq(unpackBinaryTypeCast(_1)), StringType)
case Sha2(_1, Literal(256, _)) =>
buildExtScalarFunction("Sha256", Seq(unpackBinaryTypeCast(_1)), StringType)
buildExtScalarFunction("Spark_Sha256", Seq(unpackBinaryTypeCast(_1)), StringType)
case Sha2(_1, Literal(384, _)) =>
buildExtScalarFunction("Sha384", Seq(unpackBinaryTypeCast(_1)), StringType)
buildExtScalarFunction("Spark_Sha384", Seq(unpackBinaryTypeCast(_1)), StringType)
case Sha2(_1, Literal(512, _)) =>
buildExtScalarFunction("Sha512", Seq(unpackBinaryTypeCast(_1)), StringType)
buildExtScalarFunction("Spark_Sha512", Seq(unpackBinaryTypeCast(_1)), StringType)
case Murmur3Hash(children, 42) =>
buildExtScalarFunction("Murmur3Hash", children, IntegerType)
buildExtScalarFunction("Spark_Murmur3Hash", children, IntegerType)
case XxHash64(children, 42L) =>
buildExtScalarFunction("XxHash64", children, LongType)
buildExtScalarFunction("Spark_XxHash64", children, LongType)
case e: Greatest =>
buildScalarFunction(pb.ScalarFunction.Greatest, e.children, e.dataType)
case e: Pow =>
buildScalarFunction(pb.ScalarFunction.Power, e.children, e.dataType)
case e: Nvl =>
buildScalarFunction(pb.ScalarFunction.Nvl, e.children, e.dataType)

case Year(child) => buildExtScalarFunction("Year", child :: Nil, IntegerType)
case Month(child) => buildExtScalarFunction("Month", child :: Nil, IntegerType)
case DayOfMonth(child) => buildExtScalarFunction("Day", child :: Nil, IntegerType)
case Quarter(child) => buildExtScalarFunction("Quarter", child :: Nil, IntegerType)
case Year(child) => buildExtScalarFunction("Spark_Year", child :: Nil, IntegerType)
case Month(child) => buildExtScalarFunction("Spark_Month", child :: Nil, IntegerType)
case DayOfMonth(child) => buildExtScalarFunction("Spark_Day", child :: Nil, IntegerType)
case Quarter(child) => buildExtScalarFunction("Spark_Quarter", child :: Nil, IntegerType)

case e: Levenshtein =>
buildScalarFunction(pb.ScalarFunction.Levenshtein, e.children, e.dataType)

case e: Hour if datetimeExtractEnabled =>
buildTimePartExt("Hour", e.children.head, isPruningExpr, fallback)
buildTimePartExt("Spark_Hour", e.children.head, isPruningExpr, fallback)
case e: Minute if datetimeExtractEnabled =>
buildTimePartExt("Minute", e.children.head, isPruningExpr, fallback)
buildTimePartExt("Spark_Minute", e.children.head, isPruningExpr, fallback)
case e: Second if datetimeExtractEnabled =>
buildTimePartExt("Second", e.children.head, isPruningExpr, fallback)
buildTimePartExt("Spark_Second", e.children.head, isPruningExpr, fallback)

// startswith is converted to scalar function in pruning-expr mode
case StartsWith(expr, Literal(prefix, StringType)) if isPruningExpr =>
Expand Down Expand Up @@ -949,20 +950,20 @@ object NativeConverters extends Logging {
StringType)

case StringSpace(n) =>
buildExtScalarFunction("StringSpace", n :: Nil, StringType)
buildExtScalarFunction("Spark_StringSpace", n :: Nil, StringType)

case StringRepeat(str, n @ Literal(_, IntegerType)) =>
buildExtScalarFunction("StringRepeat", str :: n :: Nil, StringType)
buildExtScalarFunction("Spark_StringRepeat", str :: n :: Nil, StringType)

case e: Concat if e.children.forall(_.dataType == StringType) =>
buildExtScalarFunction("StringConcat", e.children, e.dataType)
buildExtScalarFunction("Spark_StringConcat", e.children, e.dataType)

case e: ConcatWs
if e.children.nonEmpty
&& e.children.head.isInstanceOf[Literal]
&& e.children.forall(c =>
c.dataType == StringType || c.dataType == ArrayType(StringType)) =>
buildExtScalarFunction("StringConcatWs", e.children, e.dataType)
buildExtScalarFunction("Spark_StringConcatWs", e.children, e.dataType)

case e: Coalesce =>
val children = e.children.map(Cast(_, e.dataType))
Expand Down Expand Up @@ -1011,29 +1012,29 @@ object NativeConverters extends Logging {
// expressions for DecimalPrecision rule
case UnscaledValue(_1) if decimalArithOpEnabled =>
val args = _1 :: Nil
buildExtScalarFunction("UnscaledValue", args, LongType)
buildExtScalarFunction("Spark_UnscaledValue", args, LongType)

case e: MakeDecimal if decimalArithOpEnabled =>
val precision = e.precision
val scale = e.scale
val args =
e.child :: Literal
.apply(precision, IntegerType) :: Literal.apply(scale, IntegerType) :: Nil
buildExtScalarFunction("MakeDecimal", args, DecimalType(precision, scale))
buildExtScalarFunction("Spark_MakeDecimal", args, DecimalType(precision, scale))

case e: CheckOverflow if decimalArithOpEnabled =>
val precision = e.dataType.precision
val scale = e.dataType.scale
val args =
e.child :: Literal
.apply(precision, IntegerType) :: Literal.apply(scale, IntegerType) :: Nil
buildExtScalarFunction("CheckOverflow", args, DecimalType(precision, scale))
buildExtScalarFunction("Spark_CheckOverflow", args, DecimalType(precision, scale))

case e: NormalizeNaNAndZero
if e.dataType.isInstanceOf[FloatType] || e.dataType.isInstanceOf[DoubleType] =>
buildExtScalarFunction("NormalizeNanAndZero", e.children, e.dataType)
buildExtScalarFunction("Spark_NormalizeNanAndZero", e.children, e.dataType)

case e: CreateArray => buildExtScalarFunction("MakeArray", e.children, e.dataType)
case e: CreateArray => buildExtScalarFunction("Spark_MakeArray", e.children, e.dataType)

case e: CreateNamedStruct =>
buildExprNode {
Expand Down Expand Up @@ -1100,16 +1101,19 @@ object NativeConverters extends Logging {
// The benefit of this approach is that if there are multiple calls,
// the JSON object can be reused, which can significantly improve performance.
val parsed = Shims.get.createNativeExprWrapper(
buildExtScalarFunction("ParseJson", e.children(0) :: Nil, BinaryType),
buildExtScalarFunction("Spark_ParseJson", e.children(0) :: Nil, BinaryType),
BinaryType,
nullable = false)
buildExtScalarFunction("GetParsedJsonObject", parsed :: e.children(1) :: Nil, StringType)
buildExtScalarFunction(
"Spark_GetParsedJsonObject",
parsed :: e.children(1) :: Nil,
StringType)

// hive UDF brickhouse.array_union
case e
if getFunctionClassName(e).contains("brickhouse.udf.collect.ArrayUnionUDF")
&& udfBrickHouseEnabled =>
buildExtScalarFunction("BrickhouseArrayUnion", e.children, e.dataType)
buildExtScalarFunction("Spark_BrickhouseArrayUnion", e.children, e.dataType)

case e =>
Shims.get.convertMoreExprWithFallback(e, isPruningExpr, fallback) match {
Expand Down Expand Up @@ -1304,7 +1308,7 @@ object NativeConverters extends Logging {
pb.PhysicalScalarFunctionNode
.newBuilder()
.setName(name)
.setFun(pb.ScalarFunction.SparkExtFunctions)
.setFun(pb.ScalarFunction.AuronExtFunctions)
.addAllArgs(
args.map(expr => convertExprWithFallback(expr, isPruningExpr, fallback)).asJava)
.setReturnType(convertDataType(dataType)))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,7 @@ object NativeAggBase extends Logging {
.setScalarFunction(
pb.PhysicalScalarFunctionNode
.newBuilder()
.setFun(pb.ScalarFunction.SparkExtFunctions)
.setFun(pb.ScalarFunction.AuronExtFunctions)
.setName("Placeholder")
.setReturnType(NativeConverters.convertDataType(e.dataType)))
.build()
Expand Down
Loading