Skip to content

Commit

Permalink
Add support for encrypted async blob read (opensearch-project#10131)
Browse files Browse the repository at this point in the history
* Add support for encrypted async blob read

Signed-off-by: Kunal Kotwani <kkotwani@amazon.com>

* Add async blob read support for encrypted containers

Signed-off-by: Kunal Kotwani <kkotwani@amazon.com>

---------

Signed-off-by: Kunal Kotwani <kkotwani@amazon.com>
  • Loading branch information
kotwanikunal authored and sarthakaggarwal97 committed Sep 24, 2023
1 parent cb0e554 commit 814edb5
Show file tree
Hide file tree
Showing 6 changed files with 199 additions and 10 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
- Add metrics for thread_pool task wait time ([#9681](https://github.com/opensearch-project/OpenSearch/pull/9681))
- Async blob read support for S3 plugin ([#9694](https://github.com/opensearch-project/OpenSearch/pull/9694))
- [Telemetry-Otel] Added support for OtlpGrpcSpanExporter exporter ([#9666](https://github.com/opensearch-project/OpenSearch/pull/9666))
- Async blob read support for encrypted containers ([#10131](https://github.com/opensearch-project/OpenSearch/pull/10131))

### Dependencies
- Bump `peter-evans/create-or-update-comment` from 2 to 3 ([#9575](https://github.com/opensearch-project/OpenSearch/pull/9575))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,14 @@
import org.opensearch.common.blobstore.stream.read.ReadContext;
import org.opensearch.common.blobstore.stream.write.WriteContext;
import org.opensearch.common.crypto.CryptoHandler;
import org.opensearch.common.crypto.DecryptedRangedStreamProvider;
import org.opensearch.common.io.InputStreamContainer;
import org.opensearch.core.action.ActionListener;
import org.opensearch.threadpool.ThreadPool;

import java.io.IOException;
import java.nio.file.Path;
import java.io.InputStream;
import java.util.List;
import java.util.stream.Collectors;

/**
* EncryptedBlobContainer is an encrypted BlobContainer that is backed by a
Expand All @@ -44,12 +46,17 @@ public void asyncBlobUpload(WriteContext writeContext, ActionListener<Void> comp

@Override
public void readBlobAsync(String blobName, ActionListener<ReadContext> listener) {
throw new UnsupportedOperationException();
}

@Override
public void asyncBlobDownload(String blobName, Path fileLocation, ThreadPool threadPool, ActionListener<String> completionListener) {
throw new UnsupportedOperationException();
try {
final U cryptoContext = cryptoHandler.loadEncryptionMetadata(getEncryptedHeaderContentSupplier(blobName));
ActionListener<ReadContext> decryptingCompletionListener = ActionListener.map(
listener,
readContext -> new DecryptedReadContext<>(readContext, cryptoHandler, cryptoContext)
);

blobContainer.readBlobAsync(blobName, decryptingCompletionListener);
} catch (Exception e) {
listener.onFailure(e);
}
}

@Override
Expand Down Expand Up @@ -108,4 +115,58 @@ public InputStreamContainer provideStream(int partNumber) throws IOException {
}

}

/**
* DecryptedReadContext decrypts the encrypted {@link ReadContext} by acting as a transformation wrapper around
* the encrypted object
* @param <T> Encryption Metadata / CryptoContext for the {@link CryptoHandler} instance
* @param <U> Parsed Encryption Metadata / CryptoContext for the {@link CryptoHandler} instance
*/
static class DecryptedReadContext<T, U> extends ReadContext {

private final CryptoHandler<T, U> cryptoHandler;
private final U cryptoContext;
private Long blobSize;

public DecryptedReadContext(ReadContext readContext, CryptoHandler<T, U> cryptoHandler, U cryptoContext) {
super(readContext);
this.cryptoHandler = cryptoHandler;
this.cryptoContext = cryptoContext;
}

@Override
public long getBlobSize() {
// initializes the value lazily
if (blobSize == null) {
this.blobSize = this.cryptoHandler.estimateDecryptedLength(cryptoContext, super.getBlobSize());
}
return this.blobSize;
}

@Override
public List<InputStreamContainer> getPartStreams() {
return super.getPartStreams().stream().map(this::decryptInputStreamContainer).collect(Collectors.toList());
}

/**
* Transforms an encrypted {@link InputStreamContainer} to a decrypted instance
* @param inputStreamContainer encrypted input stream container instance
* @return decrypted input stream container instance
*/
private InputStreamContainer decryptInputStreamContainer(InputStreamContainer inputStreamContainer) {
long startOfStream = inputStreamContainer.getOffset();
long endOfStream = startOfStream + inputStreamContainer.getContentLength() - 1;
DecryptedRangedStreamProvider decryptedStreamProvider = cryptoHandler.createDecryptingStreamOfRange(
cryptoContext,
startOfStream,
endOfStream
);

long adjustedPos = decryptedStreamProvider.getAdjustedRange()[0];
long adjustedLength = decryptedStreamProvider.getAdjustedRange()[1] - adjustedPos + 1;
final InputStream decryptedStream = decryptedStreamProvider.getDecryptedStreamProvider()
.apply(inputStreamContainer.getInputStream());
return new InputStreamContainer(decryptedStream, adjustedLength, adjustedPos);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ public InputStream readBlob(String blobName) throws IOException {
return cryptoHandler.createDecryptingStream(inputStream);
}

private EncryptedHeaderContentSupplier getEncryptedHeaderContentSupplier(String blobName) {
EncryptedHeaderContentSupplier getEncryptedHeaderContentSupplier(String blobName) {
return (start, end) -> {
byte[] buffer;
int length = (int) (end - start + 1);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,12 @@ public ReadContext(long blobSize, List<InputStreamContainer> partStreams, String
this.blobChecksum = blobChecksum;
}

public ReadContext(ReadContext readContext) {
this.blobSize = readContext.blobSize;
this.partStreams = readContext.partStreams;
this.blobChecksum = readContext.blobChecksum;
}

public String getBlobChecksum() {
return blobChecksum;
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/

package org.opensearch.common.blobstore;

import org.opensearch.common.Randomness;
import org.opensearch.common.blobstore.stream.read.ReadContext;
import org.opensearch.common.blobstore.stream.read.listener.ListenerTestUtils;
import org.opensearch.common.crypto.CryptoHandler;
import org.opensearch.common.crypto.DecryptedRangedStreamProvider;
import org.opensearch.common.io.InputStreamContainer;
import org.opensearch.core.action.ActionListener;
import org.opensearch.test.OpenSearchTestCase;

import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.util.List;
import java.util.function.UnaryOperator;

import org.mockito.Mockito;

import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyLong;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

public class AsyncMultiStreamEncryptedBlobContainerTests extends OpenSearchTestCase {

// Tests the happy path scenario for decrypting a read context
@SuppressWarnings("unchecked")
public void testReadBlobAsync() throws Exception {
String testBlobName = "testBlobName";
int size = 100;

// Mock objects needed for the test
AsyncMultiStreamBlobContainer blobContainer = mock(AsyncMultiStreamBlobContainer.class);
CryptoHandler<Object, Object> cryptoHandler = mock(CryptoHandler.class);
Object cryptoContext = mock(Object.class);
when(cryptoHandler.loadEncryptionMetadata(any())).thenReturn(cryptoContext);
when(cryptoHandler.estimateDecryptedLength(any(), anyLong())).thenReturn((long) size);
long[] adjustedRanges = { 0, size - 1 };
DecryptedRangedStreamProvider rangedStreamProvider = new DecryptedRangedStreamProvider(adjustedRanges, UnaryOperator.identity());
when(cryptoHandler.createDecryptingStreamOfRange(eq(cryptoContext), anyLong(), anyLong())).thenReturn(rangedStreamProvider);

// Objects needed for API call
final byte[] data = new byte[size];
Randomness.get().nextBytes(data);
final InputStreamContainer inputStreamContainer = new InputStreamContainer(new ByteArrayInputStream(data), data.length, 0);
final ListenerTestUtils.CountingCompletionListener<ReadContext> completionListener =
new ListenerTestUtils.CountingCompletionListener<>();
final ReadContext readContext = new ReadContext(size, List.of(inputStreamContainer), null);

Mockito.doAnswer(invocation -> {
ActionListener<ReadContext> readContextActionListener = invocation.getArgument(1);
readContextActionListener.onResponse(readContext);
return null;
}).when(blobContainer).readBlobAsync(eq(testBlobName), any());

AsyncMultiStreamEncryptedBlobContainer<Object, Object> asyncMultiStreamEncryptedBlobContainer =
new AsyncMultiStreamEncryptedBlobContainer<>(blobContainer, cryptoHandler);
asyncMultiStreamEncryptedBlobContainer.readBlobAsync(testBlobName, completionListener);

// Assert results
ReadContext response = completionListener.getResponse();
assertEquals(0, completionListener.getFailureCount());
assertEquals(1, completionListener.getResponseCount());
assertNull(completionListener.getException());

assertTrue(response instanceof AsyncMultiStreamEncryptedBlobContainer.DecryptedReadContext);
assertEquals(1, response.getNumberOfParts());
assertEquals(size, response.getBlobSize());

InputStreamContainer responseContainer = response.getPartStreams().get(0);
assertEquals(0, responseContainer.getOffset());
assertEquals(size, responseContainer.getContentLength());
assertEquals(100, responseContainer.getInputStream().available());
}

// Tests the exception scenario for decrypting a read context
@SuppressWarnings("unchecked")
public void testReadBlobAsyncException() throws Exception {
String testBlobName = "testBlobName";
int size = 100;

// Mock objects needed for the test
AsyncMultiStreamBlobContainer blobContainer = mock(AsyncMultiStreamBlobContainer.class);
CryptoHandler<Object, Object> cryptoHandler = mock(CryptoHandler.class);
when(cryptoHandler.loadEncryptionMetadata(any())).thenThrow(new IOException());

// Objects needed for API call
final byte[] data = new byte[size];
Randomness.get().nextBytes(data);
final InputStreamContainer inputStreamContainer = new InputStreamContainer(new ByteArrayInputStream(data), data.length, 0);
final ListenerTestUtils.CountingCompletionListener<ReadContext> completionListener =
new ListenerTestUtils.CountingCompletionListener<>();
final ReadContext readContext = new ReadContext(size, List.of(inputStreamContainer), null);

Mockito.doAnswer(invocation -> {
ActionListener<ReadContext> readContextActionListener = invocation.getArgument(1);
readContextActionListener.onResponse(readContext);
return null;
}).when(blobContainer).readBlobAsync(eq(testBlobName), any());

AsyncMultiStreamEncryptedBlobContainer<Object, Object> asyncMultiStreamEncryptedBlobContainer =
new AsyncMultiStreamEncryptedBlobContainer<>(blobContainer, cryptoHandler);
asyncMultiStreamEncryptedBlobContainer.readBlobAsync(testBlobName, completionListener);

// Assert results
assertEquals(1, completionListener.getFailureCount());
assertEquals(0, completionListener.getResponseCount());
assertNull(completionListener.getResponse());
assertTrue(completionListener.getException() instanceof IOException);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ public class ListenerTestUtils {
* CountingCompletionListener acts as a verification instance for wrapping listener based calls.
* Keeps track of the last response, failure and count of response and failure invocations.
*/
static class CountingCompletionListener<T> implements ActionListener<T> {
public static class CountingCompletionListener<T> implements ActionListener<T> {
private int responseCount;
private int failureCount;
private T response;
Expand Down

0 comments on commit 814edb5

Please sign in to comment.