Skip to content

Commit

Permalink
Execute ModelEnforcements in TransformExecutor
Browse files Browse the repository at this point in the history
This allows a configurable application of Model Enforcement based on the
class of transform being executed, both before and after an element is
processed and after the transform completes.
  • Loading branch information
tgroh committed Apr 1, 2016
1 parent 6585f2c commit b1cc366
Show file tree
Hide file tree
Showing 5 changed files with 318 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,10 @@
*/
interface CompletionCallback {
/**
* Handle a successful result.
* Handle a successful result, returning the committed outputs of the result.
*/
void handleResult(CommittedBundle<?> inputBundle, InProcessTransformResult result);
Iterable<? extends CommittedBundle<?>> handleResult(
CommittedBundle<?> inputBundle, InProcessTransformResult result);

/**
* Handle a result that terminated abnormally due to the provided {@link Throwable}.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import com.google.cloud.dataflow.sdk.runners.inprocess.InMemoryWatermarkManager.FiredTimers;
import com.google.cloud.dataflow.sdk.runners.inprocess.InProcessPipelineRunner.CommittedBundle;
import com.google.cloud.dataflow.sdk.transforms.AppliedPTransform;
import com.google.cloud.dataflow.sdk.transforms.PTransform;
import com.google.cloud.dataflow.sdk.util.KeyedWorkItem;
import com.google.cloud.dataflow.sdk.util.KeyedWorkItems;
import com.google.cloud.dataflow.sdk.util.TimeDomain;
Expand Down Expand Up @@ -60,6 +61,10 @@ final class ExecutorServiceParallelExecutor implements InProcessExecutor {
private final Map<PValue, Collection<AppliedPTransform<?, ?, ?>>> valueToConsumers;
private final Set<PValue> keyedPValues;
private final TransformEvaluatorRegistry registry;
@SuppressWarnings("rawtypes")
private final Map<Class<? extends PTransform>, Collection<ModelEnforcementFactory>>
transformEnforcements;

private final InProcessEvaluationContext evaluationContext;

private final ConcurrentMap<StepAndKey, TransformExecutorService> currentEvaluations;
Expand All @@ -78,21 +83,26 @@ public static ExecutorServiceParallelExecutor create(
Map<PValue, Collection<AppliedPTransform<?, ?, ?>>> valueToConsumers,
Set<PValue> keyedPValues,
TransformEvaluatorRegistry registry,
@SuppressWarnings("rawtypes")
Map<Class<? extends PTransform>, Collection<ModelEnforcementFactory>> transformEnforcements,
InProcessEvaluationContext context) {
return new ExecutorServiceParallelExecutor(
executorService, valueToConsumers, keyedPValues, registry, context);
executorService, valueToConsumers, keyedPValues, registry, transformEnforcements, context);
}

private ExecutorServiceParallelExecutor(
ExecutorService executorService,
Map<PValue, Collection<AppliedPTransform<?, ?, ?>>> valueToConsumers,
Set<PValue> keyedPValues,
TransformEvaluatorRegistry registry,
@SuppressWarnings("rawtypes")
Map<Class<? extends PTransform>, Collection<ModelEnforcementFactory>> transformEnforcements,
InProcessEvaluationContext context) {
this.executorService = executorService;
this.valueToConsumers = valueToConsumers;
this.keyedPValues = keyedPValues;
this.registry = registry;
this.transformEnforcements = transformEnforcements;
this.evaluationContext = context;

currentEvaluations = new ConcurrentHashMap<>();
Expand Down Expand Up @@ -126,16 +136,29 @@ private <T> void evaluateBundle(
@Nullable final CommittedBundle<T> bundle,
final CompletionCallback onComplete) {
TransformExecutorService transformExecutor;

if (bundle != null && isKeyed(bundle.getPCollection())) {
final StepAndKey stepAndKey =
StepAndKey.of(transform, bundle == null ? null : bundle.getKey());
transformExecutor = getSerialExecutorService(stepAndKey);
} else {
transformExecutor = parallelExecutorService;
}

Collection<ModelEnforcementFactory> enforcements =
MoreObjects.firstNonNull(
transformEnforcements.get(transform.getTransform().getClass()),
Collections.<ModelEnforcementFactory>emptyList());

TransformExecutor<T> callable =
TransformExecutor.create(
registry, evaluationContext, bundle, transform, onComplete, transformExecutor);
registry,
enforcements,
evaluationContext,
bundle,
transform,
onComplete,
transformExecutor);
transformExecutor.schedule(callable);
}

Expand Down Expand Up @@ -176,12 +199,14 @@ public void awaitCompletion() throws Throwable {
*/
private class DefaultCompletionCallback implements CompletionCallback {
@Override
public void handleResult(CommittedBundle<?> inputBundle, InProcessTransformResult result) {
public Iterable<? extends CommittedBundle<?>> handleResult(
CommittedBundle<?> inputBundle, InProcessTransformResult result) {
Iterable<? extends CommittedBundle<?>> resultBundles =
evaluationContext.handleResult(inputBundle, Collections.<TimerData>emptyList(), result);
for (CommittedBundle<?> outputBundle : resultBundles) {
allUpdates.offer(ExecutorUpdate.fromBundle(outputBundle));
}
return resultBundles;
}

@Override
Expand All @@ -204,12 +229,14 @@ private TimerCompletionCallback(Iterable<TimerData> timers) {
}

@Override
public void handleResult(CommittedBundle<?> inputBundle, InProcessTransformResult result) {
public Iterable<? extends CommittedBundle<?>> handleResult(
CommittedBundle<?> inputBundle, InProcessTransformResult result) {
Iterable<? extends CommittedBundle<?>> resultBundles =
evaluationContext.handleResult(inputBundle, timers, result);
for (CommittedBundle<?> outputBundle : resultBundles) {
allUpdates.offer(ExecutorUpdate.fromBundle(outputBundle));
}
return resultBundles;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
import org.joda.time.Instant;

import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.ExecutorService;
Expand Down Expand Up @@ -243,6 +244,7 @@ public InProcessPipelineResult run(Pipeline pipeline) {
consumerTrackingVisitor.getValueToConsumers(),
keyedPValueVisitor.getKeyedPValues(),
TransformEvaluatorRegistry.defaultRegistry(),
defaultModelEnforcements(options),
context);
executor.start(consumerTrackingVisitor.getRootTransforms());

Expand All @@ -262,6 +264,11 @@ public InProcessPipelineResult run(Pipeline pipeline) {
return result;
}

/**
 * Returns the {@link ModelEnforcementFactory ModelEnforcementFactories} to apply, keyed by the
 * class of {@link PTransform} they should be applied to.
 *
 * <p>Currently this always returns an empty map, so no enforcement is registered for any
 * transform class; {@code options} is accepted but not yet consulted.
 *
 * @param options the options the pipeline is being run with (currently unused)
 * @return an empty map from transform class to the enforcement factories for that class
 */
@SuppressWarnings("rawtypes") // PTransform is referenced as a raw class token, as in the executor
private Map<Class<? extends PTransform>, Collection<ModelEnforcementFactory>>
    defaultModelEnforcements(InProcessPipelineOptions options) {
  return Collections.emptyMap();
}

/**
* The result of running a {@link Pipeline} with the {@link InProcessPipelineRunner}.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
import com.google.cloud.dataflow.sdk.util.WindowedValue;
import com.google.common.base.Throwables;

import java.util.ArrayList;
import java.util.Collection;
import java.util.concurrent.Callable;

import javax.annotation.Nullable;
Expand All @@ -35,13 +37,15 @@
class TransformExecutor<T> implements Callable<InProcessTransformResult> {
public static <T> TransformExecutor<T> create(
TransformEvaluatorFactory factory,
Iterable<? extends ModelEnforcementFactory> modelEnforcements,
InProcessEvaluationContext evaluationContext,
CommittedBundle<T> inputBundle,
AppliedPTransform<?, ?, ?> transform,
CompletionCallback completionCallback,
TransformExecutorService transformEvaluationState) {
return new TransformExecutor<>(
factory,
modelEnforcements,
evaluationContext,
inputBundle,
transform,
Expand All @@ -50,6 +54,8 @@ public static <T> TransformExecutor<T> create(
}

private final TransformEvaluatorFactory evaluatorFactory;
private final Iterable<? extends ModelEnforcementFactory> modelEnforcements;

private final InProcessEvaluationContext evaluationContext;

/** The transform that will be evaluated. */
Expand All @@ -64,12 +70,14 @@ public static <T> TransformExecutor<T> create(

private TransformExecutor(
TransformEvaluatorFactory factory,
Iterable<? extends ModelEnforcementFactory> modelEnforcements,
InProcessEvaluationContext evaluationContext,
CommittedBundle<T> inputBundle,
AppliedPTransform<?, ?, ?> transform,
CompletionCallback completionCallback,
TransformExecutorService transformEvaluationState) {
this.evaluatorFactory = factory;
this.modelEnforcements = modelEnforcements;
this.evaluationContext = evaluationContext;

this.inputBundle = inputBundle;
Expand All @@ -84,15 +92,17 @@ private TransformExecutor(
public InProcessTransformResult call() {
this.thread = Thread.currentThread();
try {
Collection<ModelEnforcement<T>> enforcements = new ArrayList<>();
for (ModelEnforcementFactory enforcementFactory : modelEnforcements) {
ModelEnforcement<T> enforcement = enforcementFactory.forBundle(inputBundle, transform);
enforcements.add(enforcement);
}
TransformEvaluator<T> evaluator =
evaluatorFactory.forApplication(transform, inputBundle, evaluationContext);
if (inputBundle != null) {
for (WindowedValue<T> value : inputBundle.getElements()) {
evaluator.processElement(value);
}
}
InProcessTransformResult result = evaluator.finishBundle();
onComplete.handleResult(inputBundle, result);

processElements(evaluator, enforcements);

InProcessTransformResult result = finishBundle(evaluator, enforcements);
return result;
} catch (Throwable t) {
onComplete.handleThrowable(inputBundle, t);
Expand All @@ -103,6 +113,46 @@ public InProcessTransformResult call() {
}
}

/**
 * Runs every element of the input bundle through the given evaluator, invoking each
 * {@link ModelEnforcement} immediately before and immediately after the element is processed.
 *
 * <p>Does nothing when there is no input bundle.
 *
 * @param evaluator the evaluator that processes each element
 * @param enforcements the enforcements to notify around each element
 * @throws Exception if the evaluator or any enforcement throws
 */
private void processElements(
    TransformEvaluator<T> evaluator, Collection<ModelEnforcement<T>> enforcements)
    throws Exception {
  if (inputBundle == null) {
    return;
  }

  for (WindowedValue<T> element : inputBundle.getElements()) {
    for (ModelEnforcement<T> check : enforcements) {
      check.beforeElement(element);
    }

    evaluator.processElement(element);

    for (ModelEnforcement<T> check : enforcements) {
      check.afterElement(element);
    }
  }
}

/**
 * Finishes the bundle on the given evaluator, commits the result through the
 * {@link CompletionCallback}, and then notifies each {@link ModelEnforcement} of the input
 * bundle, the result, and the committed outputs.
 *
 * @param evaluator the evaluator whose bundle is being finished
 * @param enforcements the enforcements to notify once the result has been handled
 * @return the {@link InProcessTransformResult} produced by
 *     {@link TransformEvaluator#finishBundle()}
 * @throws Exception if finishing the bundle, handling the result, or any enforcement throws
 */
private InProcessTransformResult finishBundle(
    TransformEvaluator<T> evaluator, Collection<ModelEnforcement<T>> enforcements)
    throws Exception {
  InProcessTransformResult bundleResult = evaluator.finishBundle();
  // The callback hands back the committed output bundles, which the enforcements inspect.
  Iterable<? extends CommittedBundle<?>> committedOutputs =
      onComplete.handleResult(inputBundle, bundleResult);
  for (ModelEnforcement<T> check : enforcements) {
    check.afterFinish(inputBundle, bundleResult, committedOutputs);
  }
  return bundleResult;
}

/**
* If this {@link TransformExecutor} is currently executing, return the thread it is executing in.
* Otherwise, return null.
Expand Down
Loading

0 comments on commit b1cc366

Please sign in to comment.