@@ -18,6 +18,8 @@ package org.apache.spark.sql.kinesis
18
18
19
19
import java .nio .ByteBuffer
20
20
21
+ import scala .util .Try
22
+
21
23
import com .amazonaws .services .kinesis .producer .{KinesisProducer , UserRecordResult }
22
24
import com .google .common .util .concurrent .{FutureCallback , Futures }
23
25
@@ -34,9 +36,19 @@ private[kinesis] class KinesisWriteTask(producerConfiguration: Map[String, Strin
34
36
private val streamName = producerConfiguration.getOrElse(
35
37
KinesisSourceProvider .SINK_STREAM_NAME_KEY , " " )
36
38
39
// Time (ms) to sleep between flush attempts while draining the producer's
// outstanding records in flushRecordsIfNecessary().  Read from the sink
// configuration, falling back to the provider default.
// Fix: the original only validated that the value parses as a Long, yet the
// error message claims "has to be a positive integer" — zero and negative
// values slipped through and would break the drain loop's Thread.sleep.
// `.filter(_ > 0)` makes the validation match the message.
private val flushWaitTimeMills = Try(producerConfiguration.getOrElse(
    KinesisSourceProvider.SINK_FLUSH_WAIT_TIME_MILLIS,
    KinesisSourceProvider.DEFAULT_FLUSH_WAIT_TIME_MILLIS).toLong)
  .filter(_ > 0)
  .getOrElse {
    throw new IllegalArgumentException(
      s"${KinesisSourceProvider.SINK_FLUSH_WAIT_TIME_MILLIS} has to be a positive integer")
  }

// First asynchronous write failure reported by the producer's
// FutureCallback; surfaced to the task via checkForErrors() and used to
// stop the write loop in execute() early.
private var failedWrite: Throwable = _
48
+
37
49
def execute (iterator : Iterator [InternalRow ]): Unit = {
38
50
producer = CachedKinesisProducer .getOrCreate(producerConfiguration)
39
- while (iterator.hasNext) {
51
+ while (iterator.hasNext && failedWrite == null ) {
40
52
val currentRow = iterator.next()
41
53
val projectedRow = projection(currentRow)
42
54
val partitionKey = projectedRow.getString(0 )
@@ -54,7 +66,10 @@ private[kinesis] class KinesisWriteTask(producerConfiguration: Map[String, Strin
54
66
val kinesisCallBack = new FutureCallback [UserRecordResult ]() {
55
67
56
68
override def onFailure (t : Throwable ): Unit = {
57
- logError(s " Writing to $streamName failed due to ${t.getCause}" )
69
+ if (failedWrite == null && t!= null ) {
70
+ failedWrite = t
71
+ logError(s " Writing to $streamName failed due to ${t.getCause}" )
72
+ }
58
73
}
59
74
60
75
override def onSuccess (result : UserRecordResult ): Unit = {
@@ -68,13 +83,34 @@ private[kinesis] class KinesisWriteTask(producerConfiguration: Map[String, Strin
68
83
sentSeqNumbers
69
84
}
70
85
71
- def close (): Unit = {
86
/**
 * Blocks until the producer has no outstanding (un-acknowledged) records,
 * polling every `flushWaitTimeMills` ms and surfacing any asynchronous
 * write failure via checkForErrors().  No-op when no producer was created.
 *
 * Fix: the original caught InterruptedException with an empty body, which
 * both discarded the thread's interrupt status and kept the drain loop
 * spinning — an interrupted (e.g. cancelled) task could busy-wait forever.
 * We now restore the interrupt flag so callers up the stack can observe it,
 * and stop polling.
 */
private def flushRecordsIfNecessary(): Unit = {
  if (producer != null) {
    var interrupted = false
    while (!interrupted && producer.getOutstandingRecordsCount > 0) {
      try {
        producer.flush()
        Thread.sleep(flushWaitTimeMills)
        checkForErrors()
      } catch {
        case _: InterruptedException =>
          // Preserve interrupt status for the caller and exit the loop.
          Thread.currentThread().interrupt()
          interrupted = true
      }
    }
  }
}
77
100
101
/**
 * Rethrows the first asynchronous write failure recorded by the producer
 * callback, if one has been seen; does nothing otherwise.
 */
def checkForErrors(): Unit = {
  // Snapshot the field once so the check and the throw observe the same value.
  val pending = failedWrite
  if (pending != null) throw pending
}
106
+
107
/**
 * Finishes the write task: surfaces any failure recorded so far, drains
 * the producer's outstanding records, re-checks for failures that arrived
 * while draining, then drops the producer reference (the producer itself
 * is cached and shared, so it is not destroyed here).
 */
def close(): Unit = {
  checkForErrors()
  flushRecordsIfNecessary()
  checkForErrors() // a callback may have failed during the flush above
  producer = null
}
113
+
78
114
private def createProjection : UnsafeProjection = {
79
115
80
116
val partitionKeyExpression = inputSchema
0 commit comments