
Commit d2d751b

Merge remote-tracking branch 'upstream/master' into refreshInsertIntoHiveTable

2 parents: 203e36c + ad0dada
14 files changed: +558 -43 lines

R/pkg/R/SQLContext.R

Lines changed: 14 additions & 6 deletions

@@ -184,8 +184,11 @@ getDefaultSqlSource <- function() {
 #'
 #' Converts R data.frame or list into SparkDataFrame.
 #'
-#' @param data an RDD or list or data.frame.
+#' @param data a list or data.frame.
 #' @param schema a list of column names or named list (StructType), optional.
+#' @param samplingRatio Currently not used.
+#' @param numPartitions the number of partitions of the SparkDataFrame. Defaults to 1, this is
+#'        limited by length of the list or number of rows of the data.frame
 #' @return A SparkDataFrame.
 #' @rdname createDataFrame
 #' @export
@@ -195,12 +198,14 @@ getDefaultSqlSource <- function() {
 #' df1 <- as.DataFrame(iris)
 #' df2 <- as.DataFrame(list(3,4,5,6))
 #' df3 <- createDataFrame(iris)
+#' df4 <- createDataFrame(cars, numPartitions = 2)
 #' }
 #' @name createDataFrame
 #' @method createDataFrame default
 #' @note createDataFrame since 1.4.0
 # TODO(davies): support sampling and infer type from NA
-createDataFrame.default <- function(data, schema = NULL, samplingRatio = 1.0) {
+createDataFrame.default <- function(data, schema = NULL, samplingRatio = 1.0,
+                                    numPartitions = NULL) {
   sparkSession <- getSparkSession()

   if (is.data.frame(data)) {
@@ -233,7 +238,11 @@ createDataFrame.default <- function(data, schema = NULL, samplingRatio = 1.0) {

   if (is.list(data)) {
     sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", sparkSession)
-    rdd <- parallelize(sc, data)
+    if (!is.null(numPartitions)) {
+      rdd <- parallelize(sc, data, numSlices = numToInt(numPartitions))
+    } else {
+      rdd <- parallelize(sc, data, numSlices = 1)
+    }
   } else if (inherits(data, "RDD")) {
     rdd <- data
   } else {
@@ -283,14 +292,13 @@ createDataFrame <- function(x, ...) {
   dispatchFunc("createDataFrame(data, schema = NULL)", x, ...)
 }

-#' @param samplingRatio Currently not used.
 #' @rdname createDataFrame
 #' @aliases createDataFrame
 #' @export
 #' @method as.DataFrame default
 #' @note as.DataFrame since 1.6.0
-as.DataFrame.default <- function(data, schema = NULL, samplingRatio = 1.0) {
-  createDataFrame(data, schema)
+as.DataFrame.default <- function(data, schema = NULL, samplingRatio = 1.0, numPartitions = NULL) {
+  createDataFrame(data, schema, samplingRatio, numPartitions)
 }

 #' @param ... additional argument(s).
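For intuition, the partition count chosen by the new createDataFrame path can be summarized as: numPartitions defaults to 1 when it is NULL, and parallelize later caps the slice count at the number of rows. A minimal Python sketch of that rule follows; the helper name is hypothetical and the size-based slice increase for very large objects is ignored.

# Hypothetical helper mirroring the partition-count rule in the R change above:
# numPartitions defaults to 1, and parallelize() caps the slice count at the
# number of rows.
def effective_partitions(num_rows, num_partitions=None):
    requested = 1 if num_partitions is None else int(num_partitions)
    return min(requested, num_rows)

print(effective_partitions(50))       # default            -> 1
print(effective_partitions(50, 20))   # numPartitions = 20 -> 20
print(effective_partitions(50, 60))   # numPartitions = 60 -> 50, capped by the 50 rows in cars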

R/pkg/R/context.R

Lines changed: 34 additions & 5 deletions

@@ -91,6 +91,16 @@ objectFile <- function(sc, path, minPartitions = NULL) {
 #' will write it to disk and send the file name to JVM. Also to make sure each slice is not
 #' larger than that limit, number of slices may be increased.
 #'
+#' In 2.2.0 we are changing how the numSlices are used/computed to handle
+#' 1 < (length(coll) / numSlices) << length(coll) better, and to get the exact number of slices.
+#' This change affects both createDataFrame and spark.lapply.
+#' In the specific one case that it is used to convert R native object into SparkDataFrame, it has
+#' always been kept at the default of 1. In the case the object is large, we are explicitly setting
+#' the parallism to numSlices (which is still 1).
+#'
+#' Specifically, we are changing to split positions to match the calculation in positions() of
+#' ParallelCollectionRDD in Spark.
+#'
 #' @param sc SparkContext to use
 #' @param coll collection to parallelize
 #' @param numSlices number of partitions to create in the RDD
@@ -107,6 +117,8 @@ parallelize <- function(sc, coll, numSlices = 1) {
   # TODO: bound/safeguard numSlices
   # TODO: unit tests for if the split works for all primitives
   # TODO: support matrix, data frame, etc
+
+  # Note, for data.frame, createDataFrame turns it into a list before it calls here.
   # nolint start
   # suppress lintr warning: Place a space before left parenthesis, except in a function call.
   if ((!is.list(coll) && !is.vector(coll)) || is.data.frame(coll)) {
@@ -128,12 +140,29 @@ parallelize <- function(sc, coll, numSlices = 1) {
   objectSize <- object.size(coll)

   # For large objects we make sure the size of each slice is also smaller than sizeLimit
-  numSlices <- max(numSlices, ceiling(objectSize / sizeLimit))
-  if (numSlices > length(coll))
-    numSlices <- length(coll)
+  numSerializedSlices <- max(numSlices, ceiling(objectSize / sizeLimit))
+  if (numSerializedSlices > length(coll))
+    numSerializedSlices <- length(coll)
+
+  # Generate the slice ids to put each row
+  # For instance, for numSerializedSlices of 22, length of 50
+  #  [1] 0 0 2 2 4 4 6 6 6 9 9 11 11 13 13 15 15 15 18 18 20 20 22 22 22
+  # [26] 25 25 27 27 29 29 31 31 31 34 34 36 36 38 38 40 40 40 43 43 45 45 47 47 47
+  # Notice the slice group with 3 slices (ie. 6, 15, 22) are roughly evenly spaced.
+  # We are trying to reimplement the calculation in the positions method in ParallelCollectionRDD
+  splits <- if (numSerializedSlices > 0) {
+    unlist(lapply(0: (numSerializedSlices - 1), function(x) {
+      # nolint start
+      start <- trunc((x * length(coll)) / numSerializedSlices)
+      end <- trunc(((x + 1) * length(coll)) / numSerializedSlices)
+      # nolint end
+      rep(start, end - start)
+    }))
+  } else {
+    1
+  }

-  sliceLen <- ceiling(length(coll) / numSlices)
-  slices <- split(coll, rep(1: (numSlices + 1), each = sliceLen)[1:length(coll)])
+  slices <- split(coll, splits)

   # Serialize each slice: obtain a list of raws, or a list of lists (slices) of
   # 2-tuples of raws
R/pkg/inst/tests/testthat/test_rdd.R

Lines changed: 2 additions & 2 deletions

@@ -381,8 +381,8 @@ test_that("aggregateRDD() on RDDs", {
 test_that("zipWithUniqueId() on RDDs", {
   rdd <- parallelize(sc, list("a", "b", "c", "d", "e"), 3L)
   actual <- collectRDD(zipWithUniqueId(rdd))
-  expected <- list(list("a", 0), list("b", 3), list("c", 1),
-                   list("d", 4), list("e", 2))
+  expected <- list(list("a", 0), list("b", 1), list("c", 4),
+                   list("d", 2), list("e", 5))
   expect_equal(actual, expected)

   rdd <- parallelize(sc, list("a", "b", "c", "d", "e"), 1L)
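The expected ids change because the new slice rule groups the five elements as ("a"), ("b", "c"), ("d", "e") rather than ("a", "b"), ("c", "d"), ("e"), and zipWithUniqueId gives the i-th item of the k-th partition the id k + i * n, where n is the number of partitions. A small Python check of the new expectations (illustration only):

# Hypothetical check of the new expected ids.  With the new slice rule the five
# elements land in partitions of sizes 1, 2, 2; zipWithUniqueId then gives item i
# of partition k the id k + i * n, where n is the number of partitions.
parts = [["a"], ["b", "c"], ["d", "e"]]   # new partitioning for numSlices = 3
n = len(parts)

unique_ids = {elem: k + i * n
              for k, part in enumerate(parts)
              for i, elem in enumerate(part)}
print(unique_ids)   # {'a': 0, 'b': 1, 'c': 4, 'd': 2, 'e': 5}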

R/pkg/inst/tests/testthat/test_sparkSQL.R

Lines changed: 22 additions & 1 deletion

@@ -196,6 +196,26 @@ test_that("create DataFrame from RDD", {
   expect_equal(dtypes(df), list(c("name", "string"), c("age", "int"), c("height", "float")))
   expect_equal(as.list(collect(where(df, df$name == "John"))),
                list(name = "John", age = 19L, height = 176.5))
+  expect_equal(getNumPartitions(toRDD(df)), 1)
+
+  df <- as.DataFrame(cars, numPartitions = 2)
+  expect_equal(getNumPartitions(toRDD(df)), 2)
+  df <- createDataFrame(cars, numPartitions = 3)
+  expect_equal(getNumPartitions(toRDD(df)), 3)
+  # validate limit by num of rows
+  df <- createDataFrame(cars, numPartitions = 60)
+  expect_equal(getNumPartitions(toRDD(df)), 50)
+  # validate when 1 < (length(coll) / numSlices) << length(coll)
+  df <- createDataFrame(cars, numPartitions = 20)
+  expect_equal(getNumPartitions(toRDD(df)), 20)
+
+  df <- as.DataFrame(data.frame(0))
+  expect_is(df, "SparkDataFrame")
+  df <- createDataFrame(list(list(1)))
+  expect_is(df, "SparkDataFrame")
+  df <- as.DataFrame(data.frame(0), numPartitions = 2)
+  # no data to partition, goes to 1
+  expect_equal(getNumPartitions(toRDD(df)), 1)

   setHiveContext(sc)
   sql("CREATE TABLE people (name string, age double, height float)")
@@ -213,7 +233,8 @@ test_that("createDataFrame uses files for large objects", {
   # To simulate a large file scenario, we set spark.r.maxAllocationLimit to a smaller value
   conf <- callJMethod(sparkSession, "conf")
   callJMethod(conf, "set", "spark.r.maxAllocationLimit", "100")
-  df <- suppressWarnings(createDataFrame(iris))
+  df <- suppressWarnings(createDataFrame(iris, numPartitions = 3))
+  expect_equal(getNumPartitions(toRDD(df)), 3)

   # Resetting the conf back to default value
   callJMethod(conf, "set", "spark.r.maxAllocationLimit", toString(.Machine$integer.max / 10))

common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java

Lines changed: 184 additions & 0 deletions

@@ -835,6 +835,190 @@ public UTF8String translate(Map<Character, Character> dict) {
     return fromString(sb.toString());
   }

+  private int getDigit(byte b) {
+    if (b >= '0' && b <= '9') {
+      return b - '0';
+    }
+    throw new NumberFormatException(toString());
+  }
+
+  /**
+   * Parses this UTF8String to long.
+   *
+   * Note that, in this method we accumulate the result in negative format, and convert it to
+   * positive format at the end, if this string is not started with '-'. This is because min value
+   * is bigger than max value in digits, e.g. Integer.MAX_VALUE is '2147483647' and
+   * Integer.MIN_VALUE is '-2147483648'.
+   *
+   * This code is mostly copied from LazyLong.parseLong in Hive.
+   */
+  public long toLong() {
+    if (numBytes == 0) {
+      throw new NumberFormatException("Empty string");
+    }
+
+    byte b = getByte(0);
+    final boolean negative = b == '-';
+    int offset = 0;
+    if (negative || b == '+') {
+      offset++;
+      if (numBytes == 1) {
+        throw new NumberFormatException(toString());
+      }
+    }
+
+    final byte separator = '.';
+    final int radix = 10;
+    final long stopValue = Long.MIN_VALUE / radix;
+    long result = 0;
+
+    while (offset < numBytes) {
+      b = getByte(offset);
+      offset++;
+      if (b == separator) {
+        // We allow decimals and will return a truncated integral in that case.
+        // Therefore we won't throw an exception here (checking the fractional
+        // part happens below.)
+        break;
+      }
+
+      int digit = getDigit(b);
+      // We are going to process the new digit and accumulate the result. However, before doing
+      // this, if the result is already smaller than the stopValue(Long.MIN_VALUE / radix), then
+      // result * 10 will definitely be smaller than minValue, and we can stop and throw exception.
+      if (result < stopValue) {
+        throw new NumberFormatException(toString());
+      }
+
+      result = result * radix - digit;
+      // Since the previous result is less than or equal to stopValue(Long.MIN_VALUE / radix), we
+      // can just use `result > 0` to check overflow. If result overflows, we should stop and throw
+      // exception.
+      if (result > 0) {
+        throw new NumberFormatException(toString());
+      }
+    }
+
+    // This is the case when we've encountered a decimal separator. The fractional
+    // part will not change the number, but we will verify that the fractional part
+    // is well formed.
+    while (offset < numBytes) {
+      if (getDigit(getByte(offset)) == -1) {
+        throw new NumberFormatException(toString());
+      }
+      offset++;
+    }
+
+    if (!negative) {
+      result = -result;
+      if (result < 0) {
+        throw new NumberFormatException(toString());
+      }
+    }
+
+    return result;
+  }
+
+  /**
+   * Parses this UTF8String to int.
+   *
+   * Note that, in this method we accumulate the result in negative format, and convert it to
+   * positive format at the end, if this string is not started with '-'. This is because min value
+   * is bigger than max value in digits, e.g. Integer.MAX_VALUE is '2147483647' and
+   * Integer.MIN_VALUE is '-2147483648'.
+   *
+   * This code is mostly copied from LazyInt.parseInt in Hive.
+   *
+   * Note that, this method is almost same as `toLong`, but we leave it duplicated for performance
+   * reasons, like Hive does.
+   */
+  public int toInt() {
+    if (numBytes == 0) {
+      throw new NumberFormatException("Empty string");
+    }
+
+    byte b = getByte(0);
+    final boolean negative = b == '-';
+    int offset = 0;
+    if (negative || b == '+') {
+      offset++;
+      if (numBytes == 1) {
+        throw new NumberFormatException(toString());
+      }
+    }
+
+    final byte separator = '.';
+    final int radix = 10;
+    final int stopValue = Integer.MIN_VALUE / radix;
+    int result = 0;
+
+    while (offset < numBytes) {
+      b = getByte(offset);
+      offset++;
+      if (b == separator) {
+        // We allow decimals and will return a truncated integral in that case.
+        // Therefore we won't throw an exception here (checking the fractional
+        // part happens below.)
+        break;
+      }
+
+      int digit = getDigit(b);
+      // We are going to process the new digit and accumulate the result. However, before doing
+      // this, if the result is already smaller than the stopValue(Integer.MIN_VALUE / radix), then
+      // result * 10 will definitely be smaller than minValue, and we can stop and throw exception.
+      if (result < stopValue) {
+        throw new NumberFormatException(toString());
+      }
+
+      result = result * radix - digit;
+      // Since the previous result is less than or equal to stopValue(Integer.MIN_VALUE / radix),
+      // we can just use `result > 0` to check overflow. If result overflows, we should stop and
+      // throw exception.
+      if (result > 0) {
+        throw new NumberFormatException(toString());
+      }
+    }
+
+    // This is the case when we've encountered a decimal separator. The fractional
+    // part will not change the number, but we will verify that the fractional part
+    // is well formed.
+    while (offset < numBytes) {
+      if (getDigit(getByte(offset)) == -1) {
+        throw new NumberFormatException(toString());
+      }
+      offset++;
+    }
+
+    if (!negative) {
+      result = -result;
+      if (result < 0) {
+        throw new NumberFormatException(toString());
+      }
+    }
+
+    return result;
+  }
+
+  public short toShort() {
+    int intValue = toInt();
+    short result = (short) intValue;
+    if (result != intValue) {
+      throw new NumberFormatException(toString());
+    }
+
+    return result;
+  }
+
+  public byte toByte() {
+    int intValue = toInt();
+    byte result = (byte) intValue;
+    if (result != intValue) {
+      throw new NumberFormatException(toString());
+    }
+
+    return result;
+  }
+
   @Override
   public String toString() {
     return new String(getBytes(), StandardCharsets.UTF_8);
python/pyspark/sql/context.py

Lines changed: 1 addition & 1 deletion

@@ -73,7 +73,7 @@ def __init__(self, sparkContext, sparkSession=None, jsqlContext=None):
         self._jsc = self._sc._jsc
         self._jvm = self._sc._jvm
         if sparkSession is None:
-            sparkSession = SparkSession(sparkContext)
+            sparkSession = SparkSession.builder.getOrCreate()
         if jsqlContext is None:
             jsqlContext = sparkSession._jwrapped
         self.sparkSession = sparkSession
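With this one-line change, an SQLContext built without an explicit session goes through SparkSession.builder.getOrCreate(), so repeated constructions over the same SparkContext share one SparkSession instead of each creating their own; the new unit test in tests.py below asserts exactly that. A short illustrative snippet (master and app name are placeholders):

# Illustrative only: two SQLContexts over the same SparkContext now share the
# active SparkSession, because __init__ calls SparkSession.builder.getOrCreate().
from pyspark import SparkContext
from pyspark.sql import SQLContext

sc = SparkContext(master="local[2]", appName="sqlcontext-reuse")  # placeholder config
a = SQLContext(sc)
b = SQLContext(sc)
assert a.sparkSession is b.sparkSession   # previously these were two sessions
sc.stop()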

python/pyspark/sql/tests.py

Lines changed: 6 additions & 1 deletion

@@ -47,7 +47,7 @@
 import unittest

 from pyspark import SparkContext
-from pyspark.sql import SparkSession, HiveContext, Column, Row
+from pyspark.sql import SparkSession, SQLContext, HiveContext, Column, Row
 from pyspark.sql.types import *
 from pyspark.sql.types import UserDefinedType, _infer_type
 from pyspark.tests import ReusedPySparkTestCase, SparkSubmitTests
@@ -206,6 +206,11 @@ def tearDownClass(cls):
         cls.spark.stop()
         shutil.rmtree(cls.tempdir.name, ignore_errors=True)

+    def test_sqlcontext_reuses_sparksession(self):
+        sqlContext1 = SQLContext(self.sc)
+        sqlContext2 = SQLContext(self.sc)
+        self.assertTrue(sqlContext1.sparkSession is sqlContext2.sparkSession)
+
     def test_row_should_be_read_only(self):
         row = Row(a=1, b=2)
         self.assertEqual(1, row.a)
