Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -166,15 +166,67 @@ public <InputT, OutputT> DoFnInvoker<InputT, OutputT> invokerFor(DoFn<InputT, Ou
private static final String FN_DELEGATE_FIELD_NAME = "delegate";

/**
* A cache of constructors of generated {@link DoFnInvoker} classes, keyed by {@link DoFn} class.
* Needed because generating an invoker class is expensive, and to avoid generating an excessive
* number of classes consuming PermGen memory.
* Cache key for DoFnInvoker constructors that includes both the DoFn class and its generic type
* parameters to prevent collisions when the same DoFn class is used with different generic types.
*/
private static final class InvokerCacheKey {
private final Class<? extends DoFn<?, ?>> fnClass;
private final TypeDescriptor<?> inputType;
private final TypeDescriptor<?> outputType;

InvokerCacheKey(
Class<? extends DoFn<?, ?>> fnClass,
TypeDescriptor<?> inputType,
TypeDescriptor<?> outputType) {
this.fnClass = fnClass;
this.inputType = inputType;
this.outputType = outputType;
}

@Override
public boolean equals(@Nullable Object o) {
if (this == o) {
return true;
}
if (!(o instanceof InvokerCacheKey)) {
return false;
}
InvokerCacheKey that = (InvokerCacheKey) o;
return fnClass.equals(that.fnClass)
&& inputType.equals(that.inputType)
&& outputType.equals(that.outputType);
}

@Override
public int hashCode() {
int result = fnClass.hashCode();
result = 31 * result + inputType.hashCode();
result = 31 * result + outputType.hashCode();
return result;
}

@Override
public String toString() {
return String.format(
"InvokerCacheKey{fnClass=%s, inputType=%s, outputType=%s}",
fnClass.getName(), inputType, outputType);
}
}

/**
* A cache of constructors of generated {@link DoFnInvoker} classes, keyed by {@link DoFn} class
* and its generic type parameters. Needed because generating an invoker class is expensive, and
* to avoid generating an excessive number of classes consuming PermGen memory.
*
* <p>The cache key includes generic type information to prevent collisions when the same DoFn
* class is used with different generic types (e.g., MyDoFn&lt;String&gt; vs
* MyDoFn&lt;Integer&gt;).
*
* <p>Note that special care must be taken to enumerate this object as concurrent hash maps are <a
* href="https://docs.oracle.com/javase/8/docs/api/java/util/concurrent/package-summary.html#Weakly>weakly
* consistent</a>.
*/
private final Map<Class<?>, Constructor<?>> byteBuddyInvokerConstructorCache =
private final Map<InvokerCacheKey, Constructor<?>> byteBuddyInvokerConstructorCache =
new ConcurrentHashMap<>();

private ByteBuddyDoFnInvokerFactory() {}
Expand Down Expand Up @@ -265,11 +317,24 @@ public <InputT, OutputT> DoFnInvoker<InputT, OutputT> newByteBuddyInvoker(
signature.fnClass(),
fn.getClass());

// Extract input and output type descriptors from the DoFn instance
// Fall back to Object.class if the type descriptors are null (e.g., for mocked DoFn instances)
@SuppressWarnings("unchecked")
TypeDescriptor<InputT> inputType = fn.getInputTypeDescriptor();
if (inputType == null) {
inputType = (TypeDescriptor<InputT>) TypeDescriptor.of(Object.class);
}
@SuppressWarnings("unchecked")
TypeDescriptor<OutputT> outputType = fn.getOutputTypeDescriptor();
if (outputType == null) {
outputType = (TypeDescriptor<OutputT>) TypeDescriptor.of(Object.class);
}

try {
@SuppressWarnings("unchecked")
DoFnInvokerBase<InputT, OutputT, DoFn<InputT, OutputT>> invoker =
(DoFnInvokerBase<InputT, OutputT, DoFn<InputT, OutputT>>)
getByteBuddyInvokerConstructor(signature).newInstance(fn);
getByteBuddyInvokerConstructor(signature, inputType, outputType).newInstance(fn);

if (signature.onTimerMethods() != null) {
for (OnTimerMethod onTimerMethod : signature.onTimerMethods().values()) {
Expand Down Expand Up @@ -297,19 +362,24 @@ public <InputT, OutputT> DoFnInvoker<InputT, OutputT> newByteBuddyInvoker(
}

/**
* Returns a generated constructor for a {@link DoFnInvoker} for the given {@link DoFn} class.
* Returns a generated constructor for a {@link DoFnInvoker} for the given {@link DoFnSignature}
* and specific generic types.
*
* <p>These are cached such that at most one {@link DoFnInvoker} class exists for a given {@link
* DoFn} class.
* DoFn} class with specific generic type parameters. Different generic instantiations of the same
* DoFn class will have separate cached invoker classes.
*/
private Constructor<?> getByteBuddyInvokerConstructor(DoFnSignature signature) {
private Constructor<?> getByteBuddyInvokerConstructor(
DoFnSignature signature, TypeDescriptor<?> inputType, TypeDescriptor<?> outputType) {
Class<? extends DoFn<?, ?>> fnClass = signature.fnClass();
InvokerCacheKey cacheKey = new InvokerCacheKey(fnClass, inputType, outputType);
return byteBuddyInvokerConstructorCache.computeIfAbsent(
fnClass,
clazz -> {
Class<? extends DoFnInvoker<?, ?>> invokerClass = generateInvokerClass(signature);
cacheKey,
key -> {
Class<? extends DoFnInvoker<?, ?>> invokerClass =
generateInvokerClass(signature, inputType, outputType);
try {
return invokerClass.getConstructor(clazz);
return invokerClass.getConstructor(fnClass);
} catch (IllegalArgumentException | NoSuchMethodException | SecurityException e) {
throw new RuntimeException(e);
}
Expand Down Expand Up @@ -457,18 +527,25 @@ public static double validateSize(double size) {
}

/** Generates a {@link DoFnInvoker} class for the given {@link DoFnSignature}. */
private static Class<? extends DoFnInvoker<?, ?>> generateInvokerClass(DoFnSignature signature) {
private static Class<? extends DoFnInvoker<?, ?>> generateInvokerClass(
DoFnSignature signature, TypeDescriptor<?> inputType, TypeDescriptor<?> outputType) {
Class<? extends DoFn<?, ?>> fnClass = signature.fnClass();

// Create a unique suffix based on the type descriptors to avoid class name collisions
// when the same DoFn class is used with different generic types.
String typeSuffix =
String.format(
"%s$%08x",
DoFnInvoker.class.getSimpleName(),
(inputType.toString() + "|" + outputType.toString()).hashCode());

final TypeDescription clazzDescription = new TypeDescription.ForLoadedType(fnClass);

DynamicType.Builder<?> builder =
new ByteBuddy()
// Create subclasses inside the target class, to have access to
// private and package-private bits
.with(
StableInvokerNamingStrategy.forDoFnClass(fnClass)
.withSuffix(DoFnInvoker.class.getSimpleName()))
.with(StableInvokerNamingStrategy.forDoFnClass(fnClass).withSuffix(typeSuffix))

// class <invoker class> extends DoFnInvokerBase {
.subclass(DoFnInvokerBase.class, ConstructorStrategy.Default.NO_CONSTRUCTORS)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import static org.hamcrest.Matchers.equalTo;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNotSame;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertSame;
import static org.junit.Assert.assertThrows;
Expand Down Expand Up @@ -77,6 +78,8 @@
import org.apache.beam.sdk.transforms.windowing.PaneInfo;
import org.apache.beam.sdk.util.UserCodeException;
import org.apache.beam.sdk.values.OutputBuilder;
import org.apache.beam.sdk.values.TypeDescriptor;
import org.apache.beam.sdk.values.TypeDescriptors;
import org.apache.beam.sdk.values.WindowedValues;
import org.joda.time.Instant;
import org.junit.Before;
Expand Down Expand Up @@ -1382,11 +1385,18 @@ public void process() {}
@Test
public void testStableName() {
DoFnInvoker<Void, Void> invoker = DoFnInvokers.invokerFor(new StableNameTestDoFn());
// The invoker class name includes a hash of the type descriptors to support
// different generic instantiations of the same DoFn class.
// Format: <DoFn class name>$<DoFnInvoker>$<type hash>
TypeDescriptor<Void> voidType = new StableNameTestDoFn().getInputTypeDescriptor();
String expectedTypeSuffix =
String.format(
"%s$%08x",
DoFnInvoker.class.getSimpleName(),
(voidType.toString() + "|" + voidType.toString()).hashCode());
assertThat(
invoker.getClass().getName(),
equalTo(
String.format(
"%s$%s", StableNameTestDoFn.class.getName(), DoFnInvoker.class.getSimpleName())));
equalTo(String.format("%s$%s", StableNameTestDoFn.class.getName(), expectedTypeSuffix)));
}

@Test
Expand All @@ -1406,4 +1416,45 @@ public void processElement(BundleFinalizer bundleFinalizer) {

verify(mockBundleFinalizer).afterBundleCommit(eq(Instant.ofEpochSecond(42L)), eq(null));
}

@Test
public void testCacheKeyCollisionProof() throws Exception {
class DynamicTypeDoFn<T> extends DoFn<T, T> {
private final TypeDescriptor<T> typeDescriptor;

DynamicTypeDoFn(TypeDescriptor<T> typeDescriptor) {
this.typeDescriptor = typeDescriptor;
}

@ProcessElement
public void processElement(@Element T element, OutputReceiver<T> out) {
out.output(element);
}

// Key point: force returning our specified type instead of relying on class signature
@Override
public TypeDescriptor<T> getInputTypeDescriptor() {
return typeDescriptor;
}

@Override
public TypeDescriptor<T> getOutputTypeDescriptor() {
return typeDescriptor;
}
}

DoFn<String, String> stringFn = new DynamicTypeDoFn<>(TypeDescriptors.strings());
DoFn<Integer, Integer> intFn = new DynamicTypeDoFn<>(TypeDescriptors.integers());

DoFnInvoker<String, String> stringInvoker = DoFnInvokers.invokerFor(stringFn);
DoFnInvoker<Integer, Integer> intInvoker = DoFnInvokers.invokerFor(intFn);

System.out.println("String Invoker: " + stringInvoker.getClass().getName());
System.out.println("Integer Invoker: " + intInvoker.getClass().getName());

assertNotSame(
"Critical bug: Beam returned the same cached class for different generic types.",
stringInvoker.getClass(),
intInvoker.getClass());
}
}
Loading