Implement pass through mode UNION ALL on Presto-on-Spark
Instead of a row-level round robin shuffle for UNION ALL, all rows in a
partition are passed through as a unit and sent to the next fragment, whose
partition id is determined by a modular hash function (the current partition
id modulo the downstream partition count).
viczhang861 authored and arhimondr committed Sep 23, 2020
1 parent 4d91a60 commit d57bbfd
Showing 7 changed files with 292 additions and 22 deletions.
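
The core of the change, as described in the commit message above, can be sketched outside of Presto's operator APIs. With a round robin shuffle, every row independently picks the next downstream partition; with pass-through mode, all rows produced by upstream task i go to downstream partition i % partitionCount. The following minimal standalone Java sketch contrasts the two strategies (all names here are illustrative, not Presto-on-Spark APIs):

    // Contrast of row-level round robin vs. task-level pass-through assignment.
    // Illustration only; this is not the Presto-on-Spark implementation.
    public final class PassThroughSketch
    {
        public static void main(String[] args)
        {
            int partitionCount = 4; // e.g. the hash_partition_count session property
            int rowsPerTask = 3;

            // Round robin: every row advances a per-task counter.
            for (int task = 0; task < 2; task++) {
                int next = 0;
                for (int row = 0; row < rowsPerTask; row++) {
                    System.out.printf("round robin: task %d row %d -> partition %d%n", task, row, next);
                    next = (next + 1) % partitionCount;
                }
            }

            // Pass-through: a whole task maps to one partition via a modular
            // function, so all of its rows stay together.
            for (int task = 0; task < 6; task++) {
                System.out.printf("pass-through: task %d (all rows) -> partition %d%n", task, task % partitionCount);
            }
        }
    }
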
@@ -95,7 +95,7 @@ public static void setDefaults(FeaturesConfig config)
         config.setColocatedJoinsEnabled(true);
         config.setRedistributeWrites(false);
         config.setScaleWriters(false);
-        config.setPreferDistributedUnion(false);
+        config.setPreferDistributedUnion(true);
         config.setForceSingleNodeOutput(false);
     }

@@ -39,6 +39,7 @@
 import java.util.OptionalInt;
 import java.util.function.Function;
 
+import static com.google.common.base.Preconditions.checkArgument;
 import static com.google.common.collect.ImmutableList.toImmutableList;
 import static java.lang.Math.toIntExact;
 import static java.util.Objects.requireNonNull;
@@ -56,13 +57,19 @@ public static class PrestoSparkRowOutputFactory
                 false,
                 OptionalInt.empty());
 
+        private final Optional<OutputPartitioning> preDeterminedPartition;
+
         private final PrestoSparkOutputBuffer<PrestoSparkRowBatch> outputBuffer;
         private final DataSize targetAverageRowSize;
 
-        public PrestoSparkRowOutputFactory(PrestoSparkOutputBuffer<PrestoSparkRowBatch> outputBuffer, DataSize targetAverageRowSize)
+        public PrestoSparkRowOutputFactory(
+                PrestoSparkOutputBuffer<PrestoSparkRowBatch> outputBuffer,
+                DataSize targetAverageRowSize,
+                Optional<OutputPartitioning> preDeterminedPartition)
         {
             this.outputBuffer = requireNonNull(outputBuffer, "outputBuffer is null");
             this.targetAverageRowSize = requireNonNull(targetAverageRowSize, "targetAverageRowSize is null");
+            this.preDeterminedPartition = requireNonNull(preDeterminedPartition, "preDeterminedPartition is null");
         }
 
         @Override
@@ -74,7 +81,7 @@ public OperatorFactory createOutputOperator(
                 Optional<OutputPartitioning> outputPartitioning,
                 PagesSerdeFactory serdeFactory)
         {
-            OutputPartitioning partitioning = outputPartitioning.orElse(SINGLE_PARTITION);
+            OutputPartitioning partitioning = outputPartitioning.orElse(preDeterminedPartition.orElse(SINGLE_PARTITION));
             return new PrestoSparkRowOutputOperatorFactory(
                     operatorId,
                     planNodeId,
@@ -107,6 +114,33 @@ public int getPartition(Page page, int position)
         }
     }
 
+    public static class PreDeterminedPartitionFunction
+            implements PartitionFunction
+    {
+        private final int partitionId;
+        private final int partitionCount;
+
+        public PreDeterminedPartitionFunction(int partitionId, int partitionCount)
+        {
+            checkArgument(partitionId >= 0 && partitionId < partitionCount,
+                    "partitionId should be non-negative and less than partitionCount");
+            this.partitionId = partitionId;
+            this.partitionCount = partitionCount;
+        }
+
+        @Override
+        public int getPartitionCount()
+        {
+            return partitionCount;
+        }
+
+        @Override
+        public int getPartition(Page page, int position)
+        {
+            return partitionId;
+        }
+    }
+
     public static class PrestoSparkRowOutputOperatorFactory
             implements OperatorFactory
     {
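
A note on the fallback chain in createOutputOperator above: an explicit output partitioning supplied by the plan takes precedence, then the pre-determined pass-through partitioning, and only then the single-partition default. A minimal sketch of the same precedence using plain java.util.Optional, with strings standing in for OutputPartitioning instances:

    import java.util.Optional;

    public final class FallbackChainSketch
    {
        // Mirrors outputPartitioning.orElse(preDeterminedPartition.orElse(SINGLE_PARTITION))
        static String resolve(Optional<String> fromPlan, Optional<String> preDetermined)
        {
            return fromPlan.orElse(preDetermined.orElse("SINGLE_PARTITION"));
        }

        public static void main(String[] args)
        {
            System.out.println(resolve(Optional.of("HASH"), Optional.of("PASS_THROUGH"))); // HASH
            System.out.println(resolve(Optional.empty(), Optional.of("PASS_THROUGH")));    // PASS_THROUGH
            System.out.println(resolve(Optional.empty(), Optional.empty()));               // SINGLE_PARTITION
        }
    }
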
@@ -56,6 +56,7 @@
 import com.facebook.presto.spark.classloader_interface.SerializedTaskInfo;
 import com.facebook.presto.spark.execution.PrestoSparkPageOutputOperator.PrestoSparkPageOutputFactory;
 import com.facebook.presto.spark.execution.PrestoSparkRowBatch.RowTupleSupplier;
+import com.facebook.presto.spark.execution.PrestoSparkRowOutputOperator.PreDeterminedPartitionFunction;
 import com.facebook.presto.spark.execution.PrestoSparkRowOutputOperator.PrestoSparkRowOutputFactory;
 import com.facebook.presto.spi.ConnectorSplit;
 import com.facebook.presto.spi.memory.MemoryPoolId;
@@ -66,6 +67,7 @@
 import com.facebook.presto.spiller.SpillSpaceTracker;
 import com.facebook.presto.sql.planner.LocalExecutionPlanner;
 import com.facebook.presto.sql.planner.LocalExecutionPlanner.LocalExecutionPlan;
+import com.facebook.presto.sql.planner.OutputPartitioning;
 import com.facebook.presto.sql.planner.PlanFragment;
 import com.facebook.presto.sql.planner.plan.PlanFragmentId;
 import com.facebook.presto.sql.planner.plan.RemoteSourceNode;
@@ -88,13 +90,15 @@
 import java.util.List;
 import java.util.NoSuchElementException;
 import java.util.Optional;
+import java.util.OptionalInt;
 import java.util.OptionalLong;
 import java.util.Set;
 import java.util.UUID;
 import java.util.concurrent.Executor;
 import java.util.concurrent.LinkedBlockingQueue;
 import java.util.concurrent.ScheduledExecutorService;
 
+import static com.facebook.presto.SystemSessionProperties.getHashPartitionCount;
 import static com.facebook.presto.execution.TaskState.FAILED;
 import static com.facebook.presto.execution.TaskStatus.STARTING_VERSION;
 import static com.facebook.presto.execution.buffer.BufferState.FINISHED;
@@ -103,6 +107,7 @@
 import static com.facebook.presto.spark.util.PrestoSparkUtils.compress;
 import static com.facebook.presto.spark.util.PrestoSparkUtils.decompress;
 import static com.facebook.presto.spark.util.PrestoSparkUtils.toPrestoSparkSerializedPage;
+import static com.facebook.presto.sql.planner.SystemPartitioningHandle.FIXED_ARBITRARY_DISTRIBUTION;
 import static com.facebook.presto.util.Failures.toFailures;
 import static com.google.common.base.Preconditions.checkArgument;
 import static com.google.common.base.Preconditions.checkState;
@@ -373,11 +378,23 @@ public <T extends PrestoSparkTaskOutput> IPrestoSparkTaskExecutor<T> doCreate(
                 () -> queryContext.getTaskContextByTaskId(taskId).localSystemMemoryContext(),
                 notificationExecutor);
         PagesSerde pagesSerde = new PagesSerde(blockEncodingManager, Optional.empty(), Optional.empty(), Optional.empty());
+
+        Optional<OutputPartitioning> preDeterminedPartition = Optional.empty();
+        if (fragment.getPartitioningScheme().getPartitioning().getHandle().equals(FIXED_ARBITRARY_DISTRIBUTION)) {
+            int partitionCount = getHashPartitionCount(session);
+            preDeterminedPartition = Optional.of(new OutputPartitioning(
+                    new PreDeterminedPartitionFunction(partitionId % partitionCount, partitionCount),
+                    ImmutableList.of(),
+                    ImmutableList.of(),
+                    false,
+                    OptionalInt.empty()));
+        }
         Output<T> output = configureOutput(
                 outputType,
                 pagesSerde,
                 memoryManager,
-                getShuffleOutputTargetAverageRowSize(session));
+                getShuffleOutputTargetAverageRowSize(session),
+                preDeterminedPartition);
         PrestoSparkOutputBuffer<?> outputBuffer = output.getOutputBuffer();
 
         LocalExecutionPlan localExecutionPlan = localExecutionPlanner.plan(
@@ -456,11 +473,12 @@ private static <T extends PrestoSparkTaskOutput> Output<T> configureOutput(
             Class<T> outputType,
             PagesSerde pagesSerde,
             OutputBufferMemoryManager memoryManager,
-            DataSize targetAverageRowSize)
+            DataSize targetAverageRowSize,
+            Optional<OutputPartitioning> preDeterminedPartition)
     {
         if (outputType.equals(PrestoSparkMutableRow.class)) {
             PrestoSparkOutputBuffer<PrestoSparkRowBatch> outputBuffer = new PrestoSparkOutputBuffer<>(memoryManager);
-            OutputFactory outputFactory = new PrestoSparkRowOutputFactory(outputBuffer, targetAverageRowSize);
+            OutputFactory outputFactory = new PrestoSparkRowOutputFactory(outputBuffer, targetAverageRowSize, preDeterminedPartition);
             OutputSupplier<T> outputSupplier = (OutputSupplier<T>) new RowOutputSupplier(outputBuffer);
             return new Output<>(OutputBufferType.SPARK_ROW_OUTPUT_BUFFER, outputBuffer, outputFactory, outputSupplier);
         }
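
A worked example of the assignment made in doCreate above, assuming hash_partition_count is 4: Spark tasks 0 through 9 running a FIXED_ARBITRARY_DISTRIBUTION fragment would send their rows to downstream partitions 0, 1, 2, 3, 0, 1, 2, 3, 0, 1. Several upstream tasks may share a downstream partition, but each task's rows always travel together, so no per-row shuffle decision is required.
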
@@ -155,16 +155,6 @@ public <T extends PrestoSparkTaskOutput> JavaPairRDD<MutablePartitionId, T> crea
             throw new PrestoException(NOT_SUPPORTED, "Automatic writers scaling is not supported by Presto on Spark");
         }
 
-        // Currently remote round robin exchange is only used in two cases
-        // - Redistribute writes:
-        //   Originally introduced to avoid skewed table writes. Makes sense with streaming exchanges
-        //   as those are very cheap. Since spark has to write the data to disk anyway the optimization
-        //   doesn't make much sense in Presto on Spark context, thus it is always disabled.
-        // - Some corner cases of UNION (e.g.: broadcasted UNION ALL)
-        //   Since round robin exchange is very costly on Spark (and potentially a correctness hazard)
-        //   such unions are always planned with Gather (SINGLE_DISTRIBUTION)
-        checkArgument(!partitioning.equals(FIXED_ARBITRARY_DISTRIBUTION), "FIXED_ARBITRARY_DISTRIBUTION is not supported");
-
         checkArgument(!partitioning.equals(COORDINATOR_DISTRIBUTION), "COORDINATOR_DISTRIBUTION fragment must be run on the driver");
         checkArgument(!partitioning.equals(FIXED_BROADCAST_DISTRIBUTION), "FIXED_BROADCAST_DISTRIBUTION can only be set as an output partitioning scheme, and not as a fragment distribution");
         checkArgument(!partitioning.equals(FIXED_PASSTHROUGH_DISTRIBUTION), "FIXED_PASSTHROUGH_DISTRIBUTION can only be set as local exchange partitioning");
@@ -179,6 +169,7 @@
 
         if (partitioning.equals(SINGLE_DISTRIBUTION) ||
                 partitioning.equals(FIXED_HASH_DISTRIBUTION) ||
+                partitioning.equals(FIXED_ARBITRARY_DISTRIBUTION) ||
                 partitioning.equals(SOURCE_DISTRIBUTION) ||
                 partitioning.getConnectorId().isPresent()) {
             for (RemoteSourceNode remoteSource : fragment.getRemoteSourceNodes()) {
@@ -217,6 +208,15 @@ private PlanFragment configureOutputPartitioning(Session session, PlanFragment f
             int hashPartitionCount = getHashPartitionCount(session);
             return fragment.withBucketToPartition(Optional.of(IntStream.range(0, hashPartitionCount).toArray()));
         }
+        // FIXED_ARBITRARY_DISTRIBUTION is used for UNION ALL
+        // UNION ALL inputs could be source inputs or shuffle inputs
+        if (outputPartitioningHandle.equals(FIXED_ARBITRARY_DISTRIBUTION)) {
+            // given the modular hash function, the partition count could be arbitrary;
+            // hash_partition_count is simply reused for convenience
+            // (it can also be set by a separate session property if needed)
+            int partitionCount = getHashPartitionCount(session);
+            return fragment.withBucketToPartition(Optional.of(IntStream.range(0, partitionCount).toArray()));
+        }
         if (outputPartitioningHandle.getConnectorId().isPresent()) {
             int connectorPartitionCount = getPartitionCount(session, outputPartitioningHandle);
             return fragment.withBucketToPartition(Optional.of(IntStream.range(0, connectorPartitionCount).toArray()));
@@ -240,7 +240,7 @@ private Partitioner createPartitioner(Session session, PartitioningHandle partit
         if (partitioning.equals(SINGLE_DISTRIBUTION)) {
             return new PrestoSparkPartitioner(1);
         }
-        if (partitioning.equals(FIXED_HASH_DISTRIBUTION)) {
+        if (partitioning.equals(FIXED_HASH_DISTRIBUTION) || partitioning.equals(FIXED_ARBITRARY_DISTRIBUTION)) {
             int hashPartitionCount = getHashPartitionCount(session);
             return new PrestoSparkPartitioner(hashPartitionCount);
         }
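
The bucket-to-partition array built in configureOutputPartitioning above is simply the identity mapping over partitionCount buckets, so bucket i feeds downstream partition i. A small sketch of the array that withBucketToPartition receives (the surrounding class is illustrative):

    import java.util.Arrays;
    import java.util.stream.IntStream;

    public final class BucketToPartitionSketch
    {
        public static void main(String[] args)
        {
            int partitionCount = 4; // reuses hash_partition_count, per the comment in the diff
            int[] bucketToPartition = IntStream.range(0, partitionCount).toArray();
            System.out.println(Arrays.toString(bucketToPartition)); // prints [0, 1, 2, 3]
        }
    }
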
@@ -203,8 +203,7 @@ public PrestoSparkQueryRunner(String defaultCatalog)
                 DRIVER,
                 ImmutableMap.of(
                         "presto.version", "testversion",
-                        "query.hash-partition-count", Integer.toString(NODE_COUNT * 2),
-                        "prefer-distributed-union", "false"),
+                        "query.hash-partition-count", Integer.toString(NODE_COUNT * 2)),
                 ImmutableMap.of(),
                 Optional.empty(),
                 new SqlParserOptions(),