Skip to content

Commit 38ac927

Browse files
committed
[SPARK-47172] Addressing reviewer comments
1 parent 5f8cd7f commit 38ac927

File tree

1 file changed

+59
-40
lines changed

1 file changed

+59
-40
lines changed

common/network-common/src/main/java/org/apache/spark/network/crypto/GcmTransportCipher.java

Lines changed: 59 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -221,27 +221,24 @@ public long transferTo(WritableByteChannel target, long position) throws IOExcep
221221
plaintextBuffer.limit(readLimit);
222222
if (plaintextMessage instanceof ByteBuf byteBuf) {
223223
byteBuf.readBytes(plaintextBuffer);
224-
long inputBytesRead = readableBytes - byteBuf.readableBytes();
225-
bytesRead += inputBytesRead;
226224
} else if (plaintextMessage instanceof AbstractFileRegion fileRegion) {
227225
ByteBufferWriteableChannel plaintextChannel =
228226
new ByteBufferWriteableChannel(plaintextBuffer);
229227
long transferred =
230228
fileRegion.transferTo(plaintextChannel, fileRegion.transferred());
231-
bytesRead += transferred;
232-
if (transferred == 0) {
233-
// File regions may return 0 if they are not ready to transfer
234-
// more data. In that case, we'll return with the expectation
235-
// that this transferTo() is called again.
229+
if (transferred < readLimit) {
230+
// If we do not read a full plaintext buffer or all the available readable bytes,
231+
// return what was transferred this call.
236232
return transferredThisCall;
237233
}
238234
}
239235
plaintextBuffer.flip();
236+
bytesRead += plaintextBuffer.remaining();
240237
ciphertextBuffer.clear();
241238
try {
242239
encrypter.encryptSegment(plaintextBuffer, lastSegment, ciphertextBuffer);
243240
} catch (GeneralSecurityException e) {
244-
throw new RuntimeException(e);
241+
throw new IllegalStateException("GeneralSecurityException from encrypter", e);
245242
}
246243
ciphertextBuffer.flip();
247244
int written = target.write(ciphertextBuffer);
@@ -279,24 +276,64 @@ protected void deallocate() {
279276

280277
@VisibleForTesting
281278
class DecryptionHandler extends ChannelInboundHandlerAdapter {
279+
private final ByteBuffer expectedLengthBuffer;
280+
private final ByteBuffer headerBuffer;
282281
private final ByteBuffer ciphertextBuffer;
283282
private final ByteBuffer plaintextBuffer;
284283
private final AesGcmHkdfStreaming aesGcmHkdfStreaming;
285284
private final StreamSegmentDecrypter decrypter;
286285
private boolean decrypterInit = false;
286+
private boolean lastSegment = false;
287287
private int segmentNumber = 0;
288288
private long expectedLength = -1;
289289
private long ciphertextRead = 0;
290290

291291
DecryptionHandler() throws GeneralSecurityException {
292292
aesGcmHkdfStreaming = getAesGcmHkdfStreaming();
293+
expectedLengthBuffer = ByteBuffer.allocate(LENGTH_HEADER_BYTES);
294+
headerBuffer = ByteBuffer.allocate(aesGcmHkdfStreaming.getHeaderLength());
293295
plaintextBuffer =
294296
ByteBuffer.allocate(aesGcmHkdfStreaming.getPlaintextSegmentSize());
295297
ciphertextBuffer =
296298
ByteBuffer.allocate(aesGcmHkdfStreaming.getCiphertextSegmentSize());
297299
decrypter = aesGcmHkdfStreaming.newStreamSegmentDecrypter();
298300
}
299301

302+
private boolean initalizeExpectedLength(ByteBuf ciphertextNettyBuf) {
303+
if (expectedLength < 0 && expectedLengthBuffer.hasRemaining()) {
304+
ciphertextNettyBuf.readBytes(expectedLengthBuffer);
305+
if (expectedLengthBuffer.hasRemaining()) {
306+
// We did not read enough bytes to initialize the expected length.
307+
return false;
308+
}
309+
expectedLengthBuffer.flip();
310+
expectedLength = expectedLengthBuffer.getLong();
311+
if (expectedLength < 0) {
312+
throw new IllegalStateException("Invalid expected ciphertext length.");
313+
}
314+
ciphertextRead += LENGTH_HEADER_BYTES;
315+
}
316+
return true;
317+
}
318+
319+
private boolean initalizeDecrypter(ByteBuf ciphertextNettyBuf) throws GeneralSecurityException {
320+
// Check if the ciphertext header has been read. This contains
321+
// the IV and other internal metadata.
322+
if (!decrypterInit && headerBuffer.hasRemaining()) {
323+
ciphertextNettyBuf.readBytes(headerBuffer);
324+
if (headerBuffer.hasRemaining()) {
325+
// We did not read enough bytes to initialize the header.
326+
return false;
327+
}
328+
headerBuffer.flip();
329+
byte[] lengthAad = Longs.toByteArray(expectedLength);
330+
decrypter.init(headerBuffer, lengthAad);
331+
decrypterInit = true;
332+
ciphertextRead += aesGcmHkdfStreaming.getHeaderLength();
333+
}
334+
return true;
335+
}
336+
300337
@Override
301338
public void channelRead(ChannelHandlerContext ctx, Object ciphertextMessage)
302339
throws GeneralSecurityException {
@@ -307,37 +344,20 @@ public void channelRead(ChannelHandlerContext ctx, Object ciphertextMessage)
307344
// The format of the output is:
308345
// [8 byte length][Internal IV and header][Ciphertext][Auth Tag]
309346
try {
310-
while (ciphertextNettyBuf.readableBytes() > 0) {
311-
// Check if the expected ciphertext length has been read.
312-
if (expectedLength < 0 &&
313-
ciphertextNettyBuf.readableBytes() >= LENGTH_HEADER_BYTES) {
314-
expectedLength = ciphertextNettyBuf.readLong();
315-
if (expectedLength < 0) {
316-
throw new IllegalStateException("Invalid expected ciphertext length.");
317-
}
318-
ciphertextRead += LENGTH_HEADER_BYTES;
319-
}
320-
int headerLength = aesGcmHkdfStreaming.getHeaderLength();
321-
// Check if the ciphertext header has been read. This contains
322-
// the IV and other internal metadata.
323-
if (!decrypterInit &&
324-
ciphertextNettyBuf.readableBytes() >= headerLength) {
325-
ByteBuffer headerBuffer = ByteBuffer.allocate(headerLength);
326-
ciphertextNettyBuf.readBytes(headerBuffer);
327-
headerBuffer.flip();
328-
byte[] lengthAad = Longs.toByteArray(expectedLength);
329-
decrypter.init(headerBuffer, lengthAad);
330-
decrypterInit = true;
331-
ciphertextRead += headerLength;
332-
}
333-
// This may occur if there weren't enough readable bytes to read the header.
334-
if (!decrypterInit) {
335-
return;
336-
}
337-
// This may occur if the expected length is just the header.
338-
if (expectedLength == ciphertextRead) {
339-
return;
340-
}
347+
if (!initalizeExpectedLength(ciphertextNettyBuf)) {
348+
// We have not read enough bytes to initialize the expected length.
349+
return;
350+
}
351+
if (!initalizeDecrypter(ciphertextNettyBuf)) {
352+
// We nave not read enough bytes to initalize a header, needed to
353+
// initialize a decrypter.
354+
return;
355+
}
356+
if (expectedLength == ciphertextRead) {
357+
// If the expected length is just the header, the ciphertext is 0 length.
358+
lastSegment = true;
359+
}
360+
while (ciphertextNettyBuf.readableBytes() > 0 && !lastSegment) {
341361
ciphertextBuffer.clear();
342362
// Read the ciphertext into the local buffer
343363
int readableBytes = Integer.min(
@@ -353,7 +373,6 @@ public void channelRead(ChannelHandlerContext ctx, Object ciphertextMessage)
353373
ciphertextNettyBuf.readBytes(ciphertextBuffer);
354374
ciphertextRead += bytesToRead;
355375
// Check if this is the last segment
356-
boolean lastSegment = false;
357376
if (ciphertextRead == expectedLength) {
358377
lastSegment = true;
359378
} else if (ciphertextRead > expectedLength) {

0 commit comments

Comments
 (0)