
Commit 374de3c

Merge pull request #113 from RedisLabs/dataframe
dataframe enhancements
2 parents 78ff03d + 117c65a commit 374de3c

18 files changed: +262 / -73 lines changed

doc/dataframe.md

Lines changed: 11 additions & 9 deletions
@@ -332,15 +332,17 @@ root
 
 ## DataFrame options
 
-| Name | Description | Type | Default |
-| ----------------- | ------------------------------------------------------------------------------------------| --------------------- | ------- |
-| model | defines Redis model used to persist DataFrame, see [Persistence model](#persistence-model)| `enum [binary, hash]` | `hash` |
-| partitions.number | number of partitions (applies only when reading dataframe) | `Int` | `3` |
-| key.column | when writing - specifies unique column used as a Redis key, by default a key is auto-generated. <br/> When reading - specifies column name to store hash key | `String` | - |
-| ttl | data time to live in `seconds`. Data doesn't expire if `ttl` is less than `1` | `Int` | `0` |
-| infer.schema | infer schema from random row, all columns will have `String` type | `Boolean` | `false` |
-| max.pipeline.size | maximum number of commands per pipeline (used to batch commands) | `Int` | 100 |
-| scan.count | count option of SCAN command (used to iterate over keys) | `Int` | 100 |
+| Name | Description | Type | Default |
+| ----------------------- | ------------------------------------------------------------------------------------------| --------------------- | ------- |
+| model | defines Redis model used to persist DataFrame, see [Persistence model](#persistence-model)| `enum [binary, hash]` | `hash` |
+| filter.keys.by.type | make sure the underlying data structures match persistence model | `Boolean` | `false` |
+| partitions.number | number of partitions (applies only when reading dataframe) | `Int` | `3` |
+| key.column | when writing - specifies unique column used as a Redis key, by default a key is auto-generated. <br/> When reading - specifies column name to store hash key | `String` | - |
+| ttl | data time to live in `seconds`. Data doesn't expire if `ttl` is less than `1` | `Int` | `0` |
+| infer.schema | infer schema from random row, all columns will have `String` type | `Boolean` | `false` |
+| max.pipeline.size | maximum number of commands per pipeline (used to batch commands) | `Int` | 100 |
+| scan.count | count option of SCAN command (used to iterate over keys) | `Int` | 100 |
+| iterator.grouping.size | the number of items to be grouped when iterating over underlying RDD partition | `Int` | 1000 |
 
 
 ## Known limitations
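
For context, a hedged usage sketch of the two options introduced here (filter.keys.by.type and iterator.grouping.size). The option keys come from the table above and the format name from this commit; the SparkSession `spark` and the "table" option value are illustrative assumptions, not part of the diff.

    // Sketch only: read hashes written earlier, skipping keys whose Redis type
    // does not match the persistence model, and iterating each partition in batches of 500.
    val df = spark.read
      .format("org.apache.spark.sql.redis")
      .option("table", "person")               // assumed table name
      .option("model", "hash")
      .option("filter.keys.by.type", "true")
      .option("iterator.grouping.size", "500")
      .load()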

src/main/scala/com/redislabs/provider/redis/RedisConfig.scala

Lines changed: 7 additions & 0 deletions
@@ -66,6 +66,13 @@ case class RedisEndpoint(host: String = Protocol.DEFAULT_HOST,
   def connect(): Jedis = {
     ConnectionPool.connect(this)
   }
+
+  /**
+    * @return config with masked password. Used for logging.
+    */
+  def maskPassword(): RedisEndpoint = {
+    this.copy(auth = "")
+  }
 }
 
 case class RedisNode(endpoint: RedisEndpoint,
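
The new maskPassword helper is used later in this commit when logging which nodes a partition reads from. A minimal sketch of the intent, assuming a class that mixes in the project's Logging trait (the endpoint values are illustrative):

    // Log an endpoint without leaking its password: auth is replaced with an empty string.
    val endpoint = RedisEndpoint(host = "redis.example.com", auth = "s3cret")
    logInfo(s"Connecting to ${endpoint.maskPassword()}")
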
Lines changed: 5 additions & 1 deletion
@@ -1,3 +1,7 @@
 package com.redislabs.provider
 
-package object redis extends RedisFunctions
+package object redis extends RedisFunctions {
+
+  val RedisDataTypeHash: String = "hash"
+  val RedisDataTypeString: String = "string"
+}
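
These constants mirror the strings returned by the Redis TYPE command and are compared against key types further down in this commit. A small sketch of that comparison, assuming an open Jedis connection:

    import com.redislabs.provider.redis.RedisDataTypeHash
    import redis.clients.jedis.Jedis

    // TYPE returns "hash", "string", "list", etc.; backticks are needed because `type` is a Scala keyword.
    def isHashKey(conn: Jedis, key: String): Boolean =
      conn.`type`(key) == RedisDataTypeHash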

src/main/scala/com/redislabs/provider/redis/rdd/RedisRDD.scala

Lines changed: 27 additions & 20 deletions
@@ -45,7 +45,7 @@ class RedisKVRDD(prev: RDD[String],
       val res = stringKeys.zip(response).iterator.asInstanceOf[Iterator[(String, String)]]
       conn.close()
       res
-    }.iterator
+    }
   }
 
   def getHASH(nodes: Array[RedisNode], keys: Iterator[String]): Iterator[(String, String)] = {
@@ -55,7 +55,7 @@ class RedisKVRDD(prev: RDD[String],
       val res = hashKeys.flatMap(conn.hgetAll).iterator
       conn.close()
       res
-    }.iterator
+    }
   }
 }
 
@@ -84,7 +84,7 @@ class RedisListRDD(prev: RDD[String],
       val res = setKeys.flatMap(conn.smembers).iterator
       conn.close()
       res
-    }.iterator
+    }
   }
 
   def getLIST(nodes: Array[RedisNode], keys: Iterator[String]): Iterator[String] = {
@@ -94,7 +94,7 @@ class RedisListRDD(prev: RDD[String],
       val res = listKeys.flatMap(conn.lrange(_, 0, -1)).iterator
       conn.close()
       res
-    }.iterator
+    }
   }
 }
 
@@ -146,7 +146,7 @@ class RedisZSetRDD[T: ClassTag](prev: RDD[String],
       }
       conn.close()
       res
-    }.iterator.asInstanceOf[Iterator[T]]
+    }.asInstanceOf[Iterator[T]]
   }
 
   private def getZSetByScore(nodes: Array[RedisNode],
@@ -168,7 +168,7 @@ class RedisZSetRDD[T: ClassTag](prev: RDD[String],
       }
       conn.close()
       res
-    }.iterator.asInstanceOf[Iterator[T]]
+    }.asInstanceOf[Iterator[T]]
   }
 }
 
@@ -255,7 +255,11 @@ class RedisKeysRDD(sc: SparkContext,
         slot >= sPos && slot <= ePos
       }).iterator
     } else {
-      getKeys(nodes, sPos, ePos, keyPattern).iterator
+      logInfo {
+        val nodesPassMasked = nodes.map(n => n.copy(endpoint = n.endpoint.maskPassword())).mkString
+        s"Computing partition, get keys partId: ${partition.index}, [$sPos - $ePos] nodes: $nodesPassMasked"
+      }
+      getKeys(nodes, sPos, ePos, keyPattern)
     }
   }
 
@@ -392,12 +396,14 @@ trait Keys {
   }
 
   /**
+    * Scan keys, the result may contain duplicates
+    *
     * @param jedis
     * @param params
     * @return keys of params pattern in jedis
     */
-  private def scanKeys(jedis: Jedis, params: ScanParams): util.HashSet[String] = {
-    val keys = new util.HashSet[String]
+  private def scanKeys(jedis: Jedis, params: ScanParams): util.List[String] = {
+    val keys = new util.ArrayList[String]
     var cursor = "0"
     do {
       val scan = jedis.scan(cursor, params)
@@ -418,24 +424,25 @@ trait Keys {
               sPos: Int,
               ePos: Int,
               keyPattern: String)
-             (implicit readWriteConfig: ReadWriteConfig): util.HashSet[String] = {
-    val keys = new util.HashSet[String]()
+             (implicit readWriteConfig: ReadWriteConfig): Iterator[String] = {
+    val endpoints = nodes.map(_.endpoint).distinct
+
     if (isRedisRegex(keyPattern)) {
-      nodes.foreach { node =>
-        val conn = node.endpoint.connect()
+      endpoints.iterator.map { endpoint =>
+        val keys = new util.HashSet[String]()
+        val conn = endpoint.connect()
         val params = new ScanParams().`match`(keyPattern).count(readWriteConfig.scanCount)
-        val res = keys.addAll(scanKeys(conn, params).filter { key =>
+        keys.addAll(scanKeys(conn, params).filter { key =>
          val slot = JedisClusterCRC16.getSlot(key)
          slot >= sPos && slot <= ePos
        })
        conn.close()
-        res
-      }
+        keys.iterator()
+      }.flatten
    } else {
      val slot = JedisClusterCRC16.getSlot(keyPattern)
-      if (slot >= sPos && slot <= ePos) keys.add(keyPattern)
+      if (slot >= sPos && slot <= ePos) Iterator(keyPattern) else Iterator()
    }
-    keys
  }
 
  /**
@@ -456,9 +463,9 @@ trait Keys {
    * @param keys list of keys
    * @return (node: (key1, key2, ...), node2: (key3, key4,...), ...)
    */
-  def groupKeysByNode(nodes: Array[RedisNode], keys: Iterator[String]): Array[(RedisNode, Array[String])] = {
+  def groupKeysByNode(nodes: Array[RedisNode], keys: Iterator[String]): Iterator[(RedisNode, Array[String])] = {
    keys.map(key => (getMasterNode(nodes, key), key)).toArray.groupBy(_._1).
-      map(x => (x._1, x._2.map(_._2))).toArray
+      map(x => (x._1, x._2.map(_._2))).iterator
  }
 
  /**
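
The slot check in getKeys decides whether a scanned key falls into the [sPos, ePos] slot range owned by the partition being computed. A minimal sketch of that check in isolation; the keys and the range are illustrative, and the JedisClusterCRC16 package name depends on the Jedis version in use:

    import redis.clients.util.JedisClusterCRC16 // redis.clients.jedis.util.JedisClusterCRC16 in Jedis 3.x

    val (sPos, ePos) = (0, 8191)                          // illustrative slot range for one partition
    val scanned = Seq("person:1", "person:2", "other:42") // illustrative keys returned by SCAN
    val keysForPartition = scanned.filter { key =>
      val slot = JedisClusterCRC16.getSlot(key)           // CRC16(key) mod 16384
      slot >= sPos && slot <= ePos
    }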

src/main/scala/org/apache/spark/sql/redis/RedisSourceRelation.scala

Lines changed: 46 additions & 23 deletions
@@ -6,7 +6,7 @@ import com.redislabs.provider.redis.rdd.Keys
 import com.redislabs.provider.redis.util.ConnectionUtils.withConnection
 import com.redislabs.provider.redis.util.Logging
 import com.redislabs.provider.redis.util.PipelineUtils._
-import com.redislabs.provider.redis.{ReadWriteConfig, RedisConfig, RedisEndpoint, RedisNode, toRedisContext}
+import com.redislabs.provider.redis.{ReadWriteConfig, RedisConfig, RedisDataTypeHash, RedisDataTypeString, RedisEndpoint, RedisNode, toRedisContext}
 import org.apache.commons.lang3.SerializationUtils
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.expressions.GenericRow
@@ -62,16 +62,19 @@ class RedisSourceRelation(override val sqlContext: SQLContext,
    */
  @volatile private var currentSchema: StructType = _
 
-  /** parameters **/
-  private val tableNameOpt: Option[String] = parameters.get(SqlOptionTableName)
-  private val keysPatternOpt: Option[String] = parameters.get(SqlOptionKeysPattern)
+  /** parameters (sorted alphabetically) **/
+  private val filterKeysByTypeEnabled = parameters.get(SqlOptionFilterKeysByType).exists(_.toBoolean)
+  private val inferSchemaEnabled = parameters.get(SqlOptionInferSchema).exists(_.toBoolean)
+  private val iteratorGroupingSize = parameters.get(SqlOptionIteratorGroupingSize).map(_.toInt)
+    .getOrElse(SqlOptionIteratorGroupingSizeDefault)
   private val keyColumn = parameters.get(SqlOptionKeyColumn)
   private val keyName = keyColumn.getOrElse("_id")
+  private val keysPatternOpt: Option[String] = parameters.get(SqlOptionKeysPattern)
   private val numPartitions = parameters.get(SqlOptionNumPartitions).map(_.toInt)
     .getOrElse(SqlOptionNumPartitionsDefault)
-  private val inferSchemaEnabled = parameters.get(SqlOptionInferSchema).exists(_.toBoolean)
   private val persistenceModel = parameters.getOrDefault(SqlOptionModel, SqlOptionModelHash)
   private val persistence = RedisPersistence(persistenceModel)
+  private val tableNameOpt: Option[String] = parameters.get(SqlOptionTableName)
   private val ttl = parameters.get(SqlOptionTTL).map(_.toInt).getOrElse(0)
 
   /**
@@ -129,15 +132,19 @@ class RedisSourceRelation(override val sqlContext: SQLContext,
 
     // write data
     data.foreachPartition { partition =>
-      val rowsWithKey: Map[String, Row] = partition.map(row => dataKeyId(row) -> row).toMap
-      groupKeysByNode(redisConfig.hosts, rowsWithKey.keysIterator).foreach { case (node, keys) =>
-        val conn = node.connect()
-        foreachWithPipeline(conn, keys) { (pipeline, key) =>
-          val row = rowsWithKey(key)
-          val encodedRow = persistence.encodeRow(keyName, row)
-          persistence.save(pipeline, key, encodedRow, ttl)
+      // grouped iterator to only allocate memory for a portion of rows
+      partition.grouped(iteratorGroupingSize).foreach { batch =>
+        // the following can be optimized to not create a map
+        val rowsWithKey: Map[String, Row] = batch.map(row => dataKeyId(row) -> row).toMap
+        groupKeysByNode(redisConfig.hosts, rowsWithKey.keysIterator).foreach { case (node, keys) =>
+          val conn = node.connect()
+          foreachWithPipeline(conn, keys) { (pipeline, key) =>
+            val row = rowsWithKey(key)
+            val encodedRow = persistence.encodeRow(keyName, row)
+            persistence.save(pipeline, key, encodedRow, ttl)
+          }
+          conn.close()
        }
-        conn.close()
      }
    }
  }
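
The write path above now consumes the partition iterator in fixed-size chunks, so at most iterator.grouping.size rows are materialized at once instead of the whole partition. The pattern itself, stripped of the Redis specifics (the generic signature is illustrative):

    // Iterator.grouped(n) yields batches of at most n elements, keeping memory bounded by the batch size.
    def writePartition[T](partition: Iterator[T], groupingSize: Int)(writeBatch: Seq[T] => Unit): Unit =
      partition.grouped(groupingSize).foreach(writeBatch)
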
@@ -158,24 +165,31 @@ class RedisSourceRelation(override val sqlContext: SQLContext,
        }
        StructType(filteredFields)
      }
+      val keyType =
+        if (persistenceModel == SqlOptionModelBinary) {
+          RedisDataTypeString
+        } else {
+          RedisDataTypeHash
+        }
      keysRdd.mapPartitions { partition =>
-        groupKeysByNode(redisConfig.hosts, partition)
-          .flatMap { case (node, keys) =>
-            scanRows(node, keys, filteredSchema, requiredColumns)
-          }
-          .iterator
+        // grouped iterator to only allocate memory for a portion of rows
+        partition.grouped(iteratorGroupingSize).map { batch =>
+          groupKeysByNode(redisConfig.hosts, batch.iterator)
+            .flatMap { case (node, keys) =>
+              scanRows(node, keys, keyType, filteredSchema, requiredColumns)
+            }
+        }.flatten
      }
    }
  }
 
-
  override def unhandledFilters(filters: Array[Filter]): Array[Filter] = filters
 
  /**
    * @return true if data exists in redis
    */
  def isEmpty: Boolean = {
-    sc.fromRedisKeyPattern(dataKeyPattern).isEmpty()
+    sc.fromRedisKeyPattern(dataKeyPattern, partitionNum = numPartitions).isEmpty()
  }
 
  /**
@@ -257,13 +271,22 @@ class RedisSourceRelation(override val sqlContext: SQLContext,
  /**
    * read rows from redis
    */
-  private def scanRows(node: RedisNode, keys: Seq[String], schema: StructType,
+  private def scanRows(node: RedisNode, keys: Seq[String], keyType: String, schema: StructType,
                       requiredColumns: Seq[String]): Seq[Row] = {
    withConnection(node.connect()) { conn =>
-      val pipelineValues = mapWithPipeline(conn, keys) { (pipeline, key) =>
+      val filteredKeys =
+        if (filterKeysByTypeEnabled) {
+          val keyTypes = mapWithPipeline(conn, keys) { (pipeline, key) =>
+            pipeline.`type`(key)
+          }
+          keys.zip(keyTypes).filter(_._2 == keyType).map(_._1)
+        } else {
+          keys
+        }
+      val pipelineValues = mapWithPipeline(conn, filteredKeys) { (pipeline, key) =>
        persistence.load(pipeline, key, requiredColumns)
      }
-      keys.zip(pipelineValues).map { case (key, value) =>
+      filteredKeys.zip(pipelineValues).map { case (key, value) =>
        val keyMap = keyName -> tableKey(keysPrefixPattern, key)
        persistence.decodeRow(keyMap, value, schema, requiredColumns)
      }

src/main/scala/org/apache/spark/sql/redis/redis.scala

Lines changed: 5 additions & 0 deletions
@@ -6,6 +6,8 @@ package org.apache.spark.sql
 package object redis {
 
   val RedisFormat = "org.apache.spark.sql.redis"
+
+  val SqlOptionFilterKeysByType = "filter.keys.by.type"
   val SqlOptionNumPartitions = "partitions.number"
   /**
     * Default read operation number of partitions.
@@ -22,4 +24,7 @@ package object redis {
 
   val SqlOptionMaxPipelineSize = "max.pipeline.size"
   val SqlOptionScanCount = "scan.count"
+
+  val SqlOptionIteratorGroupingSize = "iterator.grouping.size"
+  val SqlOptionIteratorGroupingSizeDefault = 1000
 }

src/test/scala/com/redislabs/provider/redis/df/BinaryDataframeSuite.scala

Lines changed: 54 additions & 1 deletion
@@ -1,10 +1,14 @@
 package com.redislabs.provider.redis.df
 
+import com.redislabs.provider.redis.toRedisContext
+import com.redislabs.provider.redis.util.Person
 import com.redislabs.provider.redis.util.Person._
+import com.redislabs.provider.redis.util.TestUtils._
+import org.apache.commons.lang3.SerializationUtils
 import org.apache.spark.SparkException
+import org.apache.spark.sql.redis.RedisSourceRelation.tableDataKeyPattern
 import org.apache.spark.sql.redis._
 import org.scalatest.Matchers
-import com.redislabs.provider.redis.util.TestUtils._
 
 /**
   * @author The Viet Nguyen
@@ -55,4 +59,53 @@ trait BinaryDataframeSuite extends RedisDataframeSuite with Matchers {
       .show()
    }
  }
+
+  test("load filtered hash keys with strings") {
+    val tableName = generateTableName(TableNamePrefix)
+    val df = spark.createDataFrame(data)
+    df.write.format(RedisFormat)
+      .option(SqlOptionTableName, tableName)
+      .option(SqlOptionModel, SqlOptionModelHash)
+      .save()
+    val extraKey = RedisSourceRelation.uuid()
+    saveMap(tableName, extraKey, Person.dataMaps.head)
+    val loadedIds = spark.read.format(RedisFormat)
+      .schema(Person.fullSchema)
+      .option(SqlOptionTableName, tableName)
+      .option(SqlOptionModel, SqlOptionModelHash)
+      .option(SqlOptionFilterKeysByType, value = true)
+      .load()
+      .collect()
+      .map { r =>
+        r.getAs[String]("_id")
+      }
+    loadedIds.length shouldBe 2
+    loadedIds should not contain extraKey
+    val countAll = sc.fromRedisKeyPattern(tableDataKeyPattern(tableName)).count()
+    countAll shouldBe 3
+  }
+
+  test("load unfiltered hash keys with strings") {
+    val tableName = generateTableName(TableNamePrefix)
+    val df = spark.createDataFrame(data)
+    df.write.format(RedisFormat)
+      .option(SqlOptionTableName, tableName)
+      .option(SqlOptionModel, SqlOptionModelHash)
+      .save()
+    saveMap(tableName, RedisSourceRelation.uuid(), Person.dataMaps.head)
+    intercept[SparkException] {
+      spark.read.format(RedisFormat)
+        .option(SqlOptionTableName, tableName)
+        .option(SqlOptionModel, SqlOptionModelHash)
+        .load()
+        .collect()
+    }
+  }
+
+  def serialize(value: Map[String, String]): Array[Byte] = {
+    val valuesArray = value.values.toArray
+    SerializationUtils.serialize(valuesArray)
+  }
+
+  def saveMap(tableName: String, key: String, value: Map[String, String]): Unit
 }
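
The two new tests rely on saveMap, declared abstract above, to plant a raw hash alongside the keys written by the DataFrame writer. A minimal sketch of how a concrete suite might implement it, assuming the suite exposes a redisConfig with a connectionForKey helper and that rows are keyed as "<table>:<id>" (these are assumptions, not part of the diff):

    import scala.collection.JavaConverters._

    override def saveMap(tableName: String, key: String, value: Map[String, String]): Unit = {
      val conn = redisConfig.connectionForKey(key) // assumed helper returning a Jedis connection
      try {
        conn.hmset(s"$tableName:$key", value.asJava) // HMSET stores the whole map as a Redis hash
      } finally {
        conn.close()
      }
    }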
