diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
index 65346aa49d89f..fb9e11df18833 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
@@ -898,7 +898,7 @@ object JdbcUtils extends Logging with SQLConfHelper {
       case Some(n) if n < df.rdd.getNumPartitions => df.coalesce(n)
       case _ => df
     }
-    repartitionedDF.rdd.foreachPartition { iterator => savePartition(
+    repartitionedDF.foreachPartition { iterator => savePartition(
       table, iterator, rddSchema, insertStmt, batchSize, dialect, isolationLevel, options)
     }
   }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
index 4903e31c49ca7..60785e339425e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
@@ -29,7 +29,7 @@ import org.mockito.ArgumentMatchers._
 import org.mockito.Mockito._
 
 import org.apache.spark.{SparkException, SparkSQLException}
-import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row}
+import org.apache.spark.sql.{AnalysisException, DataFrame, Observation, QueryTest, Row}
 import org.apache.spark.sql.catalyst.{analysis, TableIdentifier}
 import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
 import org.apache.spark.sql.catalyst.plans.logical.ShowCreateTable
@@ -39,6 +39,7 @@ import org.apache.spark.sql.execution.command.{ExplainCommand, ShowCreateTableCo
 import org.apache.spark.sql.execution.datasources.LogicalRelation
 import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JDBCPartition, JDBCRelation, JdbcUtils}
 import org.apache.spark.sql.execution.metric.InputOutputMetricsHelper
+import org.apache.spark.sql.functions.{lit, percentile_approx}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.sources._
 import org.apache.spark.sql.test.SharedSparkSession
@@ -2105,4 +2106,18 @@ class JDBCSuite extends QueryTest with SharedSparkSession {
       }
     }
   }
+
+  test("SPARK-45475: saving a table via JDBC should work with observe API") {
+    val tableName = "test_table"
+    val namedObservation = Observation("named")
+    val observed_df = spark.range(100).observe(
+      namedObservation, percentile_approx($"id", lit(0.5), lit(100)).as("percentile_approx_val"))
+
+    observed_df.write.format("jdbc")
+      .option("url", urlWithUserAndPass)
+      .option("dbtable", tableName).save()
+
+    val expected = Map("percentile_approx_val" -> 49)
+    assert(namedObservation.get === expected)
+  }
 }
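
Reviewer note (not part of the patch): as I understand it, the fix works because Observation is delivered its metrics through a query-execution listener, which only fires when the job runs as a Dataset action. Going through df.rdd.foreachPartition launches a raw RDD job that never completes a tracked QueryExecution, so the observed metrics are never posted and Observation.get blocks. Below is a minimal standalone sketch of that behavior, assuming a local SparkSession; the names ObserveSketch and obs are illustrative and do not appear in the patch.

import org.apache.spark.sql.{Observation, SparkSession}
import org.apache.spark.sql.functions._

object ObserveSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("observe-sketch").getOrCreate()

    val obs = Observation("stats")
    // Attach an observed metric; it is computed as rows stream through the plan.
    val df = spark.range(100).observe(obs, count(lit(1)).as("rows"))

    // A Dataset action completes a QueryExecution, so the listener fires
    // and the observed metric becomes available.
    df.foreachPartition((_: Iterator[java.lang.Long]) => ())
    assert(obs.get == Map("rows" -> 100L))

    // By contrast, df.rdd.foreachPartition(_ => ()) would run a plain RDD job:
    // the metric accumulator may still update, but no query-execution event is
    // posted, and obs.get would block indefinitely.
    spark.stop()
  }
}

This is the same distinction the added test exercises end to end: DataFrameWriter's JDBC path previously dropped to the RDD inside saveTable, which is why a named observation on the written DataFrame never resolved.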