Skip to content

Commit

Permalink
Fix order of dwrf encryption group stats
Browse files Browse the repository at this point in the history
Stats list should store encrypted stats in the order that nodes are
listed in the encryption group.
  • Loading branch information
rschlussel committed Sep 23, 2020
1 parent d57bbfd commit c857741
Show file tree
Hide file tree
Showing 2 changed files with 142 additions and 6 deletions.
11 changes: 7 additions & 4 deletions presto-orc/src/main/java/com/facebook/presto/orc/OrcWriter.java
Original file line number Diff line number Diff line change
Expand Up @@ -610,19 +610,22 @@ private List<DataOutput> bufferFileFooter()
.collect(Collectors.toMap(Entry::getKey, entry -> utf8Slice(entry.getValue())));

List<ColumnStatistics> unencryptedStats = new ArrayList<>();
Map<Integer, List<Slice>> encryptedStats = new HashMap<>();
Map<Integer, Map<Integer, Slice>> encryptedStats = new HashMap<>();
addStatsRecursive(fileStats, 0, new HashMap<>(), unencryptedStats, encryptedStats);
Optional<DwrfEncryption> dwrfEncryption;
if (dwrfWriterEncryption.isPresent()) {
ImmutableList.Builder<EncryptionGroup> encryptionGroupBuilder = ImmutableList.builder();
List<WriterEncryptionGroup> writerEncryptionGroups = dwrfWriterEncryption.get().getWriterEncryptionGroups();
for (int i = 0; i < writerEncryptionGroups.size(); i++) {
WriterEncryptionGroup group = writerEncryptionGroups.get(i);
Map<Integer, Slice> groupStats = encryptedStats.get(i);
encryptionGroupBuilder.add(
new EncryptionGroup(
group.getNodes(),
Optional.empty(), // reader will just use key metadata from the stripe
encryptedStats.get(i)));
group.getNodes().stream()
.map(groupStats::get)
.collect(toList())));
}
dwrfEncryption = Optional.of(
new DwrfEncryption(
Expand Down Expand Up @@ -657,7 +660,7 @@ private List<DataOutput> bufferFileFooter()
return outputData;
}

private void addStatsRecursive(List<ColumnStatistics> allStats, int index, Map<Integer, List<ColumnStatistics>> nodeAndSubNodeStats, List<ColumnStatistics> unencryptedStats, Map<Integer, List<Slice>> encryptedStats)
private void addStatsRecursive(List<ColumnStatistics> allStats, int index, Map<Integer, List<ColumnStatistics>> nodeAndSubNodeStats, List<ColumnStatistics> unencryptedStats, Map<Integer, Map<Integer, Slice>> encryptedStats)
throws IOException
{
if (allStats.isEmpty()) {
Expand Down Expand Up @@ -686,7 +689,7 @@ private void addStatsRecursive(List<ColumnStatistics> allStats, int index, Map<I
}
if (isRootNode) {
Slice encryptedFileStatistics = toEncryptedFileStatistics(nodeAndSubNodeStats.get(group), group);
encryptedStats.computeIfAbsent(group, x -> new ArrayList<>()).add(encryptedFileStatistics);
encryptedStats.computeIfAbsent(group, x -> new HashMap<>()).put(index, encryptedFileStatistics);
}
}
else {
Expand Down
137 changes: 135 additions & 2 deletions presto-orc/src/test/java/com/facebook/presto/orc/TestDecryption.java
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import com.facebook.presto.orc.stream.OrcInputStream;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import io.airlift.slice.BasicSliceInput;
import io.airlift.slice.Slice;
import io.airlift.slice.Slices;
import io.airlift.units.DataSize;
Expand Down Expand Up @@ -311,6 +312,72 @@ public void testSingleEncryptionGroupRowType()
outputColumns);
}

// Round-trips a BIGINT column and a VARCHAR column that belong to one encryption group
// whose nodes are listed in ascending order (1, 2).
@Test
public void testEncryptionGroupWithMultipleTypes()
        throws Exception
{
    Slice intermediateKey = Slices.utf8Slice("iek1");
    WriterEncryptionGroup encryptionGroup = new WriterEncryptionGroup(ImmutableList.of(1, 2), intermediateKey);
    DwrfWriterEncryption writerEncryption = new DwrfWriterEncryption(UNKNOWN, ImmutableList.of(encryptionGroup));

    List<Type> columnTypes = ImmutableList.of(BIGINT, VARCHAR);

    // Both columns are generated from the same integer range; the second is its string form.
    List<Long> bigintColumn = ImmutableList.copyOf(intsBetween(0, 31_234)).stream()
            .map(Number::longValue)
            .collect(toImmutableList());
    List<String> varcharColumn = ImmutableList.copyOf(intsBetween(0, 31_234)).stream()
            .map(String::valueOf)
            .collect(toImmutableList());
    List<List<?>> columnValues = ImmutableList.of(bigintColumn, varcharColumn);

    // Request every column, in declaration order.
    List<Integer> allColumns = IntStream.range(0, columnTypes.size())
            .boxed()
            .collect(toImmutableList());

    testDecryptionRoundTrip(
            columnTypes,
            columnValues,
            columnValues,
            Optional.of(writerEncryption),
            ImmutableMap.of(1, intermediateKey, 2, intermediateKey),
            ImmutableMap.of(0, BIGINT, 1, VARCHAR),
            ImmutableMap.of(),
            allColumns);
}

// Round-trips a BIGINT column and a VARCHAR column that belong to one encryption group
// whose nodes are deliberately listed in descending order (2, 1).
@Test
public void testEncryptionGroupWithReversedOrderNodes()
        throws Exception
{
    Slice intermediateKey = Slices.utf8Slice("iek1");
    // Node 2 is listed before node 1 on purpose.
    WriterEncryptionGroup reversedGroup = new WriterEncryptionGroup(ImmutableList.of(2, 1), intermediateKey);
    DwrfWriterEncryption writerEncryption = new DwrfWriterEncryption(UNKNOWN, ImmutableList.of(reversedGroup));

    List<Type> columnTypes = ImmutableList.of(BIGINT, VARCHAR);

    // Both columns are generated from the same integer range; the second is its string form.
    List<Long> bigintColumn = ImmutableList.copyOf(intsBetween(0, 31_234)).stream()
            .map(Number::longValue)
            .collect(toImmutableList());
    List<String> varcharColumn = ImmutableList.copyOf(intsBetween(0, 31_234)).stream()
            .map(String::valueOf)
            .collect(toImmutableList());
    List<List<?>> columnValues = ImmutableList.of(bigintColumn, varcharColumn);

    // Request every column, in declaration order.
    List<Integer> allColumns = IntStream.range(0, columnTypes.size())
            .boxed()
            .collect(toImmutableList());

    testDecryptionRoundTrip(
            columnTypes,
            columnValues,
            columnValues,
            Optional.of(writerEncryption),
            ImmutableMap.of(1, intermediateKey, 2, intermediateKey),
            ImmutableMap.of(0, BIGINT, 1, VARCHAR),
            ImmutableMap.of(),
            allColumns);
}

@Test
public void testMultipleEncryptionGroupsMultipleColumns()
throws Exception
Expand Down Expand Up @@ -560,12 +627,78 @@ private static void validateFileStatistics(File file, Optional<DwrfWriterEncrypt
.map(WriterEncryptionGroup::getNodes)
.flatMap(Collection::stream)
.forEach(node -> assertTrue(hasNoTypeStats(fileStats.get(node)), format("file stats for node %s had type stats %s", node, fileStats.get(node))));
footer.getEncryption().getEncryptionGroupsList()
.forEach(group -> assertEquals(group.getNodesCount(), group.getStatisticsCount()));
DwrfProto.Encryption encryption = footer.getEncryption();
EncryptionLibrary encryptionLibrary = new TestingEncryptionLibrary();
List<byte[]> keys = IntStream.range(0, dwrfWriterEncryption.get().getWriterEncryptionGroups().size())
.boxed()
.map(i -> {
byte[] encryptedKey = footer.getStripes(0).getKeyMetadata(i).toByteArray();
byte[] intermediateKey = dwrfWriterEncryption.get().getWriterEncryptionGroups().get(i).getIntermediateKeyMetadata().getBytes();
return encryptionLibrary.decryptData(intermediateKey, encryptedKey, 0, encryptedKey.length);
})
.collect(toImmutableList());
for (int i = 0; i < dwrfWriterEncryption.get().getWriterEncryptionGroups().size(); i++) {
validateEncryptionGroupStats(encryption.getEncryptionGroups(i), footer.getTypesList(), keys.get(i), orcDataSource, orcFileTail);
}
}
}
}

/**
 * Each encryption group stores a list of encrypted FileStatistics, with one entry for every node listed in the group.
 * A FileStatistics entry contains the ColumnStatistics for its node and for any child nodes.
 * This method validates that we have the expected types of stats, in the expected order, for the given encryption group.
 **/
private static void validateEncryptionGroupStats(DwrfProto.EncryptionGroup group, List<DwrfProto.Type> typesList, byte[] key, OrcDataSource orcDataSource, OrcFileTail orcFileTail)
        throws IOException
{
    int nodeCount = group.getNodesCount();
    // There must be exactly one statistics entry per node in the group.
    assertEquals(nodeCount, group.getStatisticsCount());

    Optional<OrcDecompressor> decompressor = createOrcDecompressor(orcDataSource.getId(), orcFileTail.getCompressionKind(), orcFileTail.getBufferSize(), false);
    for (int position = 0; position < nodeCount; position++) {
        DwrfDataEncryptor decryptor = new DwrfDataEncryptor(key, new TestingEncryptionLibrary());
        byte[] encryptedStatistics = group.getStatistics(position).toByteArray();
        try (InputStream statisticsStream = new OrcInputStream(
                orcDataSource.getId(),
                new BasicSliceInput(Slices.wrappedBuffer(encryptedStatistics)),
                decompressor,
                Optional.of(decryptor),
                NOOP_ORC_AGGREGATED_MEMORY_CONTEXT,
                encryptedStatistics.length)) {
            DwrfProto.FileStatistics fileStatistics = DwrfProto.FileStatistics.parseFrom(CodedInputStream.newInstance(statisticsStream));
            // The statistics at this position must describe the node listed at the same position.
            int lastMatchedIndex = assertStatsTypesMatch(fileStatistics, typesList, typesList.get(group.getNodes(position)), 0);
            // Every statistics entry must have been consumed by the node and its descendants.
            assertEquals(lastMatchedIndex, fileStatistics.getStatisticsCount() - 1);
        }
    }
}

/**
 * Asserts that the statistics entry at {@code statsIndex} carries the kind of statistics
 * expected for {@code type}, then recursively checks each subtype against the entries that follow.
 *
 * @param fileStats statistics whose entries are checked
 * @param types list used to resolve subtype ids to their type definitions
 * @param type the type expected at {@code statsIndex}
 * @param statsIndex index of the statistics entry to check for {@code type}
 * @return the index of the last statistics entry checked for {@code type} and its descendants
 */
private static int assertStatsTypesMatch(DwrfProto.FileStatistics fileStats, List<DwrfProto.Type> types, DwrfProto.Type type, int statsIndex)
{
    DwrfProto.ColumnStatistics columnStatistics = fileStats.getStatistics(statsIndex);
    switch (type.getKind()) {
        case BINARY:
            assertTrue(columnStatistics.hasBinaryStatistics());
            break;
        case BOOLEAN:
            assertTrue(columnStatistics.hasBucketStatistics());
            break;
        case BYTE:
        case SHORT:
        case INT:
        case LONG:
            assertTrue(columnStatistics.hasIntStatistics());
            break;
        case FLOAT:
        case DOUBLE:
            assertTrue(columnStatistics.hasDoubleStatistics());
            break;
        case STRING:
            assertTrue(columnStatistics.hasStringStatistics());
            break;
        default:
            // Every other kind is expected to carry no type-specific statistics.
            assertTrue(hasNoTypeStats(columnStatistics));
    }

    // Each subtype starts at the entry right after the last one used by the previous subtree.
    int lastIndex = statsIndex;
    for (int child = 0; child < type.getSubtypesCount(); child++) {
        DwrfProto.Type childType = types.get(type.getSubtypes(child));
        lastIndex = assertStatsTypesMatch(fileStats, types, childType, lastIndex + 1);
    }
    return lastIndex;
}

private static boolean hasNoTypeStats(DwrfProto.ColumnStatistics columnStatistics)
{
return !columnStatistics.hasBinaryStatistics()
Expand Down

0 comments on commit c857741

Please sign in to comment.