Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,10 @@
import com.google.cloud.dataflow.sdk.coders.CannotProvideCoderException;
import com.google.cloud.dataflow.sdk.coders.Coder;
import com.google.cloud.dataflow.sdk.coders.CoderRegistry;
import com.google.cloud.dataflow.sdk.transforms.Combine.CombineFn;
import com.google.cloud.dataflow.sdk.transforms.Combine.KeyedCombineFn;
import com.google.cloud.dataflow.sdk.transforms.CombineFnBase.GlobalCombineFn;
import com.google.cloud.dataflow.sdk.transforms.CombineWithContext.CombineFnWithContext;
import com.google.cloud.dataflow.sdk.transforms.CombineWithContext.Context;
import com.google.cloud.dataflow.sdk.transforms.CombineWithContext.KeyedCombineFnWithContext;
import com.google.cloud.dataflow.sdk.util.state.StateContext;
Expand All @@ -46,6 +49,60 @@ public class CombineFnUtil {
return new NonSerializableBoundedKeyedCombineFn<>(combineFn, context);
}

/**
* Return a {@link CombineFnWithContext} from the given {@link GlobalCombineFn}.
*/
public static <InputT, AccumT, OutputT>
CombineFnWithContext<InputT, AccumT, OutputT> toFnWithContext(
GlobalCombineFn<InputT, AccumT, OutputT> globalCombineFn) {
if (globalCombineFn instanceof CombineFnWithContext) {
@SuppressWarnings("unchecked")
CombineFnWithContext<InputT, AccumT, OutputT> combineFnWithContext =
(CombineFnWithContext<InputT, AccumT, OutputT>) globalCombineFn;
return combineFnWithContext;
} else {
@SuppressWarnings("unchecked")
final CombineFn<InputT, AccumT, OutputT> combineFn =
(CombineFn<InputT, AccumT, OutputT>) globalCombineFn;
return new CombineFnWithContext<InputT, AccumT, OutputT>() {
@Override
public AccumT createAccumulator(Context c) {
return combineFn.createAccumulator();
}
@Override
public AccumT addInput(AccumT accumulator, InputT input, Context c) {
return combineFn.addInput(accumulator, input);
}
@Override
public AccumT mergeAccumulators(Iterable<AccumT> accumulators, Context c) {
return combineFn.mergeAccumulators(accumulators);
}
@Override
public OutputT extractOutput(AccumT accumulator, Context c) {
return combineFn.extractOutput(accumulator);
}
@Override
public AccumT compact(AccumT accumulator, Context c) {
return combineFn.compact(accumulator);
}
@Override
public OutputT defaultValue() {
return combineFn.defaultValue();
}
@Override
public Coder<AccumT> getAccumulatorCoder(CoderRegistry registry, Coder<InputT> inputCoder)
throws CannotProvideCoderException {
return combineFn.getAccumulatorCoder(registry, inputCoder);
}
@Override
public Coder<OutputT> getDefaultOutputCoder(
CoderRegistry registry, Coder<InputT> inputCoder) throws CannotProvideCoderException {
return combineFn.getDefaultOutputCoder(registry, inputCoder);
}
};
}
}

private static class NonSerializableBoundedKeyedCombineFn<K, InputT, AccumT, OutputT>
extends KeyedCombineFn<K, InputT, AccumT, OutputT> {
private final KeyedCombineFnWithContext<K, InputT, AccumT, OutputT> combineFn;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -181,15 +181,21 @@ public static <K, StateT extends State> StateTag<K, StateT> makeSystemTagInterna
public static <K, InputT, AccumT, OutputT> StateTag<Object, BagState<AccumT>>
convertToBagTagInternal(
StateTag<? super K, AccumulatorCombiningState<InputT, AccumT, OutputT>> combiningTag) {
if (!(combiningTag instanceof KeyedCombiningValueStateTag)) {
if (combiningTag instanceof KeyedCombiningValueStateTag) {
// Checked above; conversion to a bag tag depends on the provided tag being one of those
// created via the factory methods in this class.
@SuppressWarnings("unchecked")
KeyedCombiningValueStateTag<K, InputT, AccumT, OutputT> typedTag =
(KeyedCombiningValueStateTag<K, InputT, AccumT, OutputT>) combiningTag;
return typedTag.asBagTag();
} else if (combiningTag instanceof KeyedCombiningValueWithContextStateTag) {
@SuppressWarnings("unchecked")
KeyedCombiningValueWithContextStateTag<K, InputT, AccumT, OutputT> typedTag =
(KeyedCombiningValueWithContextStateTag<K, InputT, AccumT, OutputT>) combiningTag;
return typedTag.asBagTag();
} else {
throw new IllegalArgumentException("Unexpected StateTag " + combiningTag);
}
// Checked above; conversion to a bag tag depends on the provided tag being one of those
// created via the factory methods in this class.
@SuppressWarnings("unchecked")
KeyedCombiningValueStateTag<K, InputT, AccumT, OutputT> typedTag =
(KeyedCombiningValueStateTag<K, InputT, AccumT, OutputT>) combiningTag;
return typedTag.asBagTag();
}

private static class StructuredId implements Serializable {
Expand Down Expand Up @@ -413,6 +419,10 @@ public StateTag<K, AccumulatorCombiningState<InputT, AccumT, OutputT>> asKind(
return new KeyedCombiningValueWithContextStateTag<>(
id.asKind(kind), accumCoder, combineFn);
}

private StateTag<Object, BagState<AccumT>> asBagTag() {
return new BagStateTag<AccumT>(id, accumCoder);
}
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,15 @@

import com.google.cloud.dataflow.sdk.coders.BigEndianIntegerCoder;
import com.google.cloud.dataflow.sdk.coders.Coder;
import com.google.cloud.dataflow.sdk.coders.CoderRegistry;
import com.google.cloud.dataflow.sdk.coders.VarIntCoder;
import com.google.cloud.dataflow.sdk.transforms.Combine.Holder;
import com.google.cloud.dataflow.sdk.transforms.Max;
import com.google.cloud.dataflow.sdk.transforms.Max.MaxIntegerFn;
import com.google.cloud.dataflow.sdk.transforms.Min;
import com.google.cloud.dataflow.sdk.transforms.Min.MinIntegerFn;
import com.google.cloud.dataflow.sdk.transforms.windowing.OutputTimeFns;
import com.google.cloud.dataflow.sdk.util.CombineFnUtil;

import org.junit.Test;
import org.junit.runner.RunWith;
Expand Down Expand Up @@ -80,6 +83,7 @@ public void testWatermarkBagEquality() {
assertEquals(bar, bar2);
}

@SuppressWarnings({"unchecked", "rawtypes"})
@Test
public void testCombiningValueEquality() {
MaxIntegerFn maxFn = new Max.MaxIntegerFn();
Expand All @@ -96,13 +100,74 @@ public void testCombiningValueEquality() {

// Same name, coder and combineFn
assertEquals(fooCoder1Max1, fooCoder1Max2);
assertEquals(
StateTags.convertToBagTagInternal((StateTag) fooCoder1Max1),
StateTags.convertToBagTagInternal((StateTag) fooCoder1Max2));

// Different combineFn, but we treat them as equal since we only serialize the bits.
assertEquals(fooCoder1Max1, fooCoder1Min);
assertEquals(
StateTags.convertToBagTagInternal((StateTag) fooCoder1Max1),
StateTags.convertToBagTagInternal((StateTag) fooCoder1Min));

// Different input coder coder.
assertNotEquals(fooCoder1Max1, fooCoder2Max);
assertNotEquals(
StateTags.convertToBagTagInternal((StateTag) fooCoder1Max1),
StateTags.convertToBagTagInternal((StateTag) fooCoder2Max));

// These StateTags have different IDs.
assertNotEquals(fooCoder1Max1, barCoder1Max);
assertNotEquals(
StateTags.convertToBagTagInternal((StateTag) fooCoder1Max1),
StateTags.convertToBagTagInternal((StateTag) barCoder1Max));
}

@SuppressWarnings({"unchecked", "rawtypes"})
@Test
public void testCombiningValueWithContextEquality() {
CoderRegistry registry = new CoderRegistry();
registry.registerStandardCoders();

MaxIntegerFn maxFn = new Max.MaxIntegerFn();
MinIntegerFn minFn = new Min.MinIntegerFn();

Coder<Holder<Integer>> accum1 = maxFn.getAccumulatorCoder(registry, VarIntCoder.of());
Coder<Holder<Integer>> accum2 = minFn.getAccumulatorCoder(registry, BigEndianIntegerCoder.of());

StateTag<?, ?> fooCoder1Max1 = StateTags.keyedCombiningValueWithContext(
"foo", accum1, CombineFnUtil.toFnWithContext(maxFn).<String>asKeyedFn());
StateTag<?, ?> fooCoder1Max2 = StateTags.keyedCombiningValueWithContext(
"foo", accum1, CombineFnUtil.toFnWithContext(maxFn).asKeyedFn());
StateTag<?, ?> fooCoder1Min = StateTags.keyedCombiningValueWithContext(
"foo", accum1, CombineFnUtil.toFnWithContext(minFn).asKeyedFn());

StateTag<?, ?> fooCoder2Max = StateTags.keyedCombiningValueWithContext(
"foo", accum2, CombineFnUtil.toFnWithContext(maxFn).asKeyedFn());
StateTag<?, ?> barCoder1Max = StateTags.keyedCombiningValueWithContext(
"bar", accum1, CombineFnUtil.toFnWithContext(maxFn).asKeyedFn());

// Same name, coder and combineFn
assertEquals(fooCoder1Max1, fooCoder1Max2);
assertEquals(
StateTags.convertToBagTagInternal((StateTag) fooCoder1Max1),
StateTags.convertToBagTagInternal((StateTag) fooCoder1Max2));
// Different combineFn, but we treat them as equal since we only serialize the bits.
assertEquals(fooCoder1Max1, fooCoder1Min);
assertEquals(
StateTags.convertToBagTagInternal((StateTag) fooCoder1Max1),
StateTags.convertToBagTagInternal((StateTag) fooCoder1Min));

// Different input coder coder.
assertNotEquals(fooCoder1Max1, fooCoder2Max);
assertNotEquals(
StateTags.convertToBagTagInternal((StateTag) fooCoder1Max1),
StateTags.convertToBagTagInternal((StateTag) fooCoder2Max));

// These StateTags have different IDs.
assertNotEquals(fooCoder1Max1, barCoder1Max);
assertNotEquals(
StateTags.convertToBagTagInternal((StateTag) fooCoder1Max1),
StateTags.convertToBagTagInternal((StateTag) barCoder1Max));
}
}