try spark 4.0 SNAPSHOT
huaxingao committed Oct 8, 2024
1 parent a76373d commit 90dbee7
Showing 10 changed files with 58 additions and 19 deletions.
gradle/libs.versions.toml (2 changes: 1 addition & 1 deletion)
@@ -86,7 +86,7 @@ snowflake-jdbc = "3.18.0"
spark-hive33 = "3.3.4"
spark-hive34 = "3.4.3"
spark-hive35 = "3.5.2"
spark-hive40 = "4.0.0-preview1"
spark-hive40 = "4.0.0-SNAPSHOT"
spring-boot = "2.7.18"
spring-web = "5.3.39"
sqlite-jdbc = "3.46.1.0"

@@ -135,13 +135,13 @@ case class ResolveViews(spark: SparkSession) extends Rule[LogicalPlan] with Look
private def qualifyFunctionIdentifiers(
plan: LogicalPlan,
catalogAndNamespace: Seq[String]): LogicalPlan = plan transformExpressions {
-case u@UnresolvedFunction(Seq(name), _, _, _, _, _) =>
+case u@UnresolvedFunction(Seq(name), _, _, _, _, _, _) =>
if (!isBuiltinFunction(name)) {
u.copy(nameParts = catalogAndNamespace :+ name)
} else {
u
}
-case u@UnresolvedFunction(parts, _, _, _, _, _) if !isCatalog(parts.head) =>
+case u@UnresolvedFunction(parts, _, _, _, _, _, _) if !isCatalog(parts.head) =>
u.copy(nameParts = catalogAndNamespace.head +: parts)
}

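Both matchers above changed only in arity: UnresolvedFunction gains an extra constructor field in Spark 4.0.0-SNAPSHOT, so each positional extractor needs a seventh wildcard (the same fix appears in RewriteViewCommands below). A minimal sketch, not part of this commit, of the same rewrite written with a type pattern so it would not need touching on every arity change; isBuiltinFunction, isCatalog and catalogAndNamespace are the existing helpers in this rule:

    // Sketch only: match on the type and read nameParts directly instead of
    // destructuring the full UnresolvedFunction constructor.
    plan transformExpressions {
      case u: UnresolvedFunction if u.nameParts.size == 1 =>
        if (isBuiltinFunction(u.nameParts.head)) u
        else u.copy(nameParts = catalogAndNamespace :+ u.nameParts.head)
      case u: UnresolvedFunction if !isCatalog(u.nameParts.head) =>
        u.copy(nameParts = catalogAndNamespace.head +: u.nameParts)
    }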

@@ -169,7 +169,7 @@ case class RewriteViewCommands(spark: SparkSession) extends Rule[LogicalPlan] wi
private def collectTemporaryFunctions(child: LogicalPlan): Seq[String] = {
val tempFunctions = new mutable.HashSet[String]()
child.resolveExpressionsWithPruning(_.containsAnyPattern(UNRESOLVED_FUNCTION)) {
-case f @ UnresolvedFunction(nameParts, _, _, _, _, _) if isTempFunction(nameParts) =>
+case f @ UnresolvedFunction(nameParts, _, _, _, _, _, _) if isTempFunction(nameParts) =>
tempFunctions += nameParts.head
f
case e: SubqueryExpression =>

@@ -34,6 +34,7 @@ import org.apache.spark.sql.catalyst.FunctionIdentifier
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.analysis.RewriteViewCommands
import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.catalyst.parser.CompoundBody
import org.apache.spark.sql.catalyst.parser.ParserInterface
import org.apache.spark.sql.catalyst.parser.extensions.IcebergSqlExtensionsParser.NonReservedContext
import org.apache.spark.sql.catalyst.parser.extensions.IcebergSqlExtensionsParser.QuotedIdentifierContext

@@ -101,6 +102,10 @@ class IcebergSparkSqlExtensionsParser(delegate: ParserInterface) extends ParserI
delegate.parseTableSchema(sqlText)
}

+override def parseScript(sqlScriptText: String): CompoundBody = {
+  delegate.parseScript(sqlScriptText)
+}

override def parseSortOrder(sqlText: String): java.util.List[RawOrderField] = {
val fields = parse(sqlText) { parser => astBuilder.visitSingleOrder(parser.singleOrder()) }
fields.map { field =>

@@ -38,19 +38,22 @@ case class ShowV2ViewsExec(
val rows = new ArrayBuffer[InternalRow]()

// handle GLOBAL VIEWS
-val globalTemp = session.sessionState.catalog.globalTempViewManager.database
-if (namespace.nonEmpty && globalTemp == namespace.head) {
-pattern.map(p => session.sessionState.catalog.globalTempViewManager.listViewNames(p))
-.getOrElse(session.sessionState.catalog.globalTempViewManager.listViewNames("*"))
-.map(name => rows += toCatalystRow(globalTemp, name, true))
-} else {
+// TODO: GlobalTempViewManager.database is not accessible any more.
+// Comment out the global views handling for now until
+// GlobalTempViewManager.database can be accessed again.
+// val globalTemp = session.sessionState.catalog.globalTempViewManager.database
+// if (namespace.nonEmpty && globalTemp == namespace.head) {
+// pattern.map(p => session.sessionState.catalog.globalTempViewManager.listViewNames(p))
+// .getOrElse(session.sessionState.catalog.globalTempViewManager.listViewNames("*"))
+// .map(name => rows += toCatalystRow(globalTemp, name, true))
+// } else {
val views = catalog.listViews(namespace: _*)
views.map { view =>
if (pattern.map(StringUtils.filterPattern(Seq(view.name()), _).nonEmpty).getOrElse(true)) {
rows += toCatalystRow(view.namespace().quoted, view.name(), false)
}
}
-}
+// }

// include TEMP VIEWS
pattern.map(p => session.sessionState.catalog.listLocalTempViews(p))
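The TODO above exists because GlobalTempViewManager.database is no longer accessible against the 4.0.0-SNAPSHOT, so the GLOBAL VIEWS branch is commented out rather than ported. A rough sketch, not part of this commit, of one way the branch could be restored without that field, assuming the spark.sql.globalTempDatabase static conf (StaticSQLConf.GLOBAL_TEMP_DATABASE) remains readable in Spark 4.0:

    // Sketch only: take the global temp database name from the static SQL conf
    // instead of from GlobalTempViewManager.database.
    import org.apache.spark.sql.internal.StaticSQLConf

    val globalTemp = session.sessionState.conf.getConf(StaticSQLConf.GLOBAL_TEMP_DATABASE)
    if (namespace.nonEmpty && globalTemp == namespace.head) {
      pattern.map(p => session.sessionState.catalog.globalTempViewManager.listViewNames(p))
        .getOrElse(session.sessionState.catalog.globalTempViewManager.listViewNames("*"))
        .foreach(name => rows += toCatalystRow(globalTemp, name, true))
    }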

@@ -29,11 +29,13 @@
import org.apache.iceberg.Parameters;
import org.apache.iceberg.RowLevelOperationMode;
import org.apache.iceberg.TableProperties;
+import org.apache.iceberg.exceptions.AlreadyExistsException;
import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList;
import org.apache.iceberg.spark.SparkCatalogConfig;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.Row;
+import org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException;
import org.apache.spark.sql.catalyst.expressions.ApplyFunctionExpression;
import org.apache.spark.sql.catalyst.expressions.Expression;
import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke;

@@ -228,7 +230,12 @@ private void checkDelete(RowLevelOperationMode mode, String cond) {
DistributionMode.NONE.modeName());

Dataset<Row> changeDF = spark.table(tableName).where(cond).limit(2).select("id");
-changeDF.coalesce(1).writeTo(tableName(CHANGES_TABLE_NAME)).create();
+try {
+  changeDF.coalesce(1).writeTo(tableName(CHANGES_TABLE_NAME)).create();
+} catch (TableAlreadyExistsException e) {
+  throw new AlreadyExistsException(
+      "Cannot create table %s as it already exists", CHANGES_TABLE_NAME);
+}

List<Expression> calls =
executeAndCollectFunctionCalls(

@@ -260,7 +267,12 @@ private void checkUpdate(RowLevelOperationMode mode, String cond) {
DistributionMode.NONE.modeName());

Dataset<Row> changeDF = spark.table(tableName).where(cond).limit(2).select("id");
-changeDF.coalesce(1).writeTo(tableName(CHANGES_TABLE_NAME)).create();
+try {
+  changeDF.coalesce(1).writeTo(tableName(CHANGES_TABLE_NAME)).create();
+} catch (TableAlreadyExistsException e) {
+  throw new AlreadyExistsException(
+      "Cannot create table %s as it already exists", CHANGES_TABLE_NAME);
+}

List<Expression> calls =
executeAndCollectFunctionCalls(

@@ -291,7 +303,12 @@ private void checkMerge(RowLevelOperationMode mode, String cond) {

Dataset<Row> changeDF =
spark.table(tableName).where(cond).limit(2).selectExpr("id + 1 as id");
-changeDF.coalesce(1).writeTo(tableName(CHANGES_TABLE_NAME)).create();
+try {
+  changeDF.coalesce(1).writeTo(tableName(CHANGES_TABLE_NAME)).create();
+} catch (TableAlreadyExistsException e) {
+  throw new AlreadyExistsException(
+      "Cannot create table %s as it already exists", CHANGES_TABLE_NAME);
+}

List<Expression> calls =
executeAndCollectFunctionCalls(

@@ -29,6 +29,7 @@
import org.apache.spark.network.util.JavaUtils;
import org.apache.spark.sql.RuntimeConfig;
import org.apache.spark.sql.SparkSession;
+import org.apache.spark.sql.internal.RuntimeConfigImpl;
import org.apache.spark.sql.internal.SQLConf;
import org.apache.spark.sql.util.CaseInsensitiveStringMap;


@@ -40,7 +41,7 @@ class SparkConfParser {

SparkConfParser() {
this.properties = ImmutableMap.of();
-this.sessionConf = new RuntimeConfig(SQLConf.get());
+this.sessionConf = new RuntimeConfigImpl(SQLConf.get());
this.options = CaseInsensitiveStringMap.empty();
}


@@ -38,6 +38,7 @@
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
+import org.apache.spark.sql.internal.ExpressionColumnNode;
import org.apache.spark.sql.stats.ThetaSketchAgg;

public class NDVSketchUtil {

@@ -84,6 +85,6 @@ private static Column[] toAggColumns(List<String> colNames) {

private static Column toAggColumn(String colName) {
ThetaSketchAgg agg = new ThetaSketchAgg(colName);
-return new Column(agg.toAggregateExpression());
+return new Column(ExpressionColumnNode.apply(agg.toAggregateExpression()));
}
}

@@ -230,7 +230,7 @@ private static Long propertyAsLong(CaseInsensitiveStringMap options, String prop
}

private static void setupDefaultSparkCatalogs(SparkSession spark) {
-if (!spark.conf().contains(DEFAULT_CATALOG)) {
+if (spark.conf().getOption(DEFAULT_CATALOG).isEmpty()) {
ImmutableMap<String, String> config =
ImmutableMap.of(
"type", "hive",

@@ -241,7 +241,7 @@ private static void setupDefaultSparkCatalogs(SparkSession spark) {
config.forEach((key, value) -> spark.conf().set(DEFAULT_CATALOG + "." + key, value));
}

-if (!spark.conf().contains(DEFAULT_CACHE_CATALOG)) {
+if (spark.conf().getOption(DEFAULT_CACHE_CATALOG).isEmpty()) {
spark.conf().set(DEFAULT_CACHE_CATALOG, SparkCachedTableCatalog.class.getName());
}
}

@@ -28,12 +28,15 @@ import org.apache.datasketches.theta.Sketch
import org.apache.datasketches.theta.UpdateSketch
import org.apache.iceberg.spark.SparkSchemaUtil
import org.apache.iceberg.types.Conversions
+import org.apache.spark.SparkException
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.expressions.aggregate.ImperativeAggregate
import org.apache.spark.sql.catalyst.expressions.aggregate.TypedImperativeAggregate
import org.apache.spark.sql.catalyst.trees.UnaryLike
import org.apache.spark.sql.functions.col
+import org.apache.spark.sql.internal.ColumnNode
+import org.apache.spark.sql.internal.ExpressionColumnNode
import org.apache.spark.sql.types.BinaryType
import org.apache.spark.sql.types.DataType
import org.apache.spark.sql.types.Decimal

@@ -56,7 +59,7 @@ case class ThetaSketchAgg(
private lazy val icebergType = SparkSchemaUtil.convert(child.dataType)

def this(colName: String) = {
-this(col(colName).expr, 0, 0)
+this(ThetaSketchAgg.expr(col(colName).node), 0, 0)
}

override def dataType: DataType = BinaryType

@@ -119,3 +122,12 @@ case class ThetaSketchAgg(
compactSketch.toByteArray
}
}

+object ThetaSketchAgg {
+  def expr(node: ColumnNode): Expression = {
+    node match {
+      case ExpressionColumnNode(expression, _) => expression
+      case node => throw SparkException.internalError("Unsupported ColumnNode: " + node)
+    }
+  }
+}

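The NDVSketchUtil and ThetaSketchAgg changes are two halves of the same Spark 4.0.0-SNAPSHOT adjustment: a Column is now backed by a ColumnNode rather than a catalyst Expression, so NDVSketchUtil wraps the aggregate Expression in an ExpressionColumnNode to build a Column, and the new ThetaSketchAgg.expr helper unwraps col(...).node back into an Expression. A small sketch of the round trip, using only calls that appear in this commit (the df.agg line is illustrative):

    // Sketch only: Expression <-> Column round trip under the new Column API.
    import org.apache.spark.sql.Column
    import org.apache.spark.sql.internal.ExpressionColumnNode

    val agg = new ThetaSketchAgg("id")  // internally unwraps col("id").node into an Expression
    val sketchCol = new Column(ExpressionColumnNode.apply(agg.toAggregateExpression()))
    // df.agg(sketchCol) would then yield the serialized theta sketch bytes for column "id".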