Skip to content

Commit

Permalink
cache compiled datasets (elastic#11482)
Browse files Browse the repository at this point in the history
  • Loading branch information
colinsurprenant authored Jan 13, 2020
1 parent 1bc0aea commit 4b29112
Show file tree
Hide file tree
Showing 4 changed files with 142 additions and 80 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import co.elastic.logstash.api.Codec;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.jruby.RubyArray;
import org.jruby.RubyHash;
import org.jruby.javasupport.JavaUtil;
import org.jruby.runtime.builtin.IRubyObject;
Expand Down Expand Up @@ -34,6 +35,8 @@
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicReference;
import java.util.stream.Collectors;
import java.util.stream.Stream;

Expand Down Expand Up @@ -78,6 +81,18 @@ public final class CompiledPipeline {
*/
private final RubyIntegration.PluginFactory pluginFactory;

/**
* Per pipeline compiled classes cache shared across threads {@link CompiledExecution}
*/
private final Map<String, Class<? extends Dataset>> datasetClassCache = new ConcurrentHashMap<>(500);

/**
* First, constructor time, compilation of the pipeline that will warm
* the {@link CompiledPipeline#datasetClassCache} in a thread safe way
* before the concurrent per worker threads {@link CompiledExecution} compilations
*/
private final AtomicReference<CompiledExecution> warmedCompiledExecution = new AtomicReference<>();

public CompiledPipeline(
final PipelineIR pipelineIR,
final RubyIntegration.PluginFactory pluginFactory) {
Expand All @@ -96,6 +111,10 @@ public CompiledPipeline(
inputs = setupInputs(cve);
filters = setupFilters(cve);
outputs = setupOutputs(cve);

// invoke a first compilation to warm the class cache which will prevent
// redundant compilations for each subsequent worker {@link CompiledExecution}
warmedCompiledExecution.set(new CompiledPipeline.CompiledExecution());
} catch (Exception e) {
throw new IllegalStateException("Unable to configure plugins: " + e.getMessage());
}
Expand All @@ -119,6 +138,10 @@ public Collection<IRubyObject> inputs() {
* @return Compiled {@link Dataset} representation of the underlying {@link PipelineIR} topology
*/
public Dataset buildExecution() {
CompiledExecution result = warmedCompiledExecution.getAndSet(null);
if (result != null) {
return result.toDataset();
}
return new CompiledPipeline.CompiledExecution().toDataset();
}

Expand Down Expand Up @@ -269,6 +292,17 @@ private boolean isOutput(final Vertex vertex) {
return outputs.containsKey(vertex.getId());
}

/**
* Returns an existing compiled dataset class implementation for the given {@code vertexId},
* or compiles one from the provided {@code computeStepSyntaxElement}.
* @param vertexId a string uniquely identifying a {@link Vertex} within the current pipeline
* @param computeStepSyntaxElement the source from which to compile a dataset class
* @return an implementation of {@link Dataset} for the given vertex
*/
private Class<? extends Dataset> getDatasetClass(final String vertexId, final ComputeStepSyntaxElement<? extends Dataset> computeStepSyntaxElement) {
return datasetClassCache.computeIfAbsent(vertexId, _vid -> computeStepSyntaxElement.compile());
}

/**
* Instances of this class represent a fully compiled pipeline execution. Note that this class
* has a separate lifecycle from {@link CompiledPipeline} because it holds per (worker-thread)
Expand All @@ -279,13 +313,13 @@ private final class CompiledExecution {
/**
* Compiled {@link IfVertex, indexed by their ID as returned by {@link Vertex#getId()}.
*/
private final Map<String, SplitDataset> iffs = new HashMap<>(5);
private final Map<String, SplitDataset> iffs = new HashMap<>(50);

/**
* Cached {@link Dataset} compiled from {@link PluginVertex} indexed by their ID as returned
* by {@link Vertex#getId()} to avoid duplicate computations.
*/
private final Map<String, Dataset> plugins = new HashMap<>(5);
private final Map<String, Dataset> plugins = new HashMap<>(50);

private final Dataset compiled;

Expand All @@ -308,11 +342,37 @@ private Dataset compile() {
if (outputNodes.isEmpty()) {
return Dataset.IDENTITY;
} else {
return DatasetCompiler.terminalDataset(outputNodes.stream().map(
return terminalDataset(outputNodes.stream().map(
leaf -> outputDataset(leaf, flatten(Collections.emptyList(), leaf))
).collect(Collectors.toList()));
}
}
/**
* <p>Builds a terminal {@link Dataset} from the given parent {@link Dataset}s.</p>
* <p>If the given set of parent {@link Dataset} is empty the sum is defined as the
* trivial dataset that does not invoke any computation whatsoever.</p>
* {@link Dataset#compute(RubyArray, boolean, boolean)} is always
* {@link Collections#emptyList()}.
* @param parents Parent {@link Dataset} to sum and terminate
* @return Dataset representing the sum of given parent {@link Dataset}
*/
public Dataset terminalDataset(final Collection<Dataset> parents) {
final int count = parents.size();
final Dataset result;
if (count > 1) {
ComputeStepSyntaxElement<Dataset> prepared = DatasetCompiler.terminalDataset(parents);
result = prepared.instantiate(prepared.compile());
} else if (count == 1) {
// No need for a terminal dataset here, if there is only a single parent node we can
// call it directly.
result = parents.iterator().next();
} else {
throw new IllegalArgumentException(
"Cannot create Terminal Dataset for an empty number of parent datasets"
);
}
return result;
}

/**
* Build a {@link Dataset} representing the {@link JrubyEventExtLibrary.RubyEvent}s after
Expand All @@ -325,12 +385,14 @@ private Dataset filterDataset(final Vertex vertex, final Collection<Dataset> dat
final String vertexId = vertex.getId();

if (!plugins.containsKey(vertexId)) {
final ComputeStepSyntaxElement<Dataset> prepared =
DatasetCompiler.filterDataset(flatten(datasets, vertex),
filters.get(vertexId));
final ComputeStepSyntaxElement<Dataset> prepared = DatasetCompiler.filterDataset(
flatten(datasets, vertex),
filters.get(vertexId));
final Class<? extends Dataset> clazz = getDatasetClass(vertexId, prepared);

LOGGER.debug("Compiled filter\n {} \n into \n {}", vertex, prepared);

plugins.put(vertexId, prepared.instantiate());
plugins.put(vertexId, prepared.instantiate(clazz));
}

return plugins.get(vertexId);
Expand All @@ -347,13 +409,16 @@ private Dataset outputDataset(final Vertex vertex, final Collection<Dataset> dat
final String vertexId = vertex.getId();

if (!plugins.containsKey(vertexId)) {
final ComputeStepSyntaxElement<Dataset> prepared =
DatasetCompiler.outputDataset(flatten(datasets, vertex),
outputs.get(vertexId),
outputs.size() == 1);
final ComputeStepSyntaxElement<Dataset> prepared = DatasetCompiler.outputDataset(
flatten(datasets, vertex),
outputs.get(vertexId),
outputs.size() == 1);
final Class<? extends Dataset> clazz = getDatasetClass(vertexId, prepared);

LOGGER.debug("Compiled output\n {} \n into \n {}", vertex, prepared);
plugins.put(vertexId, prepared.instantiate());
}

plugins.put(vertexId, prepared.instantiate(clazz));
}

return plugins.get(vertexId);
}
Expand All @@ -368,24 +433,25 @@ private Dataset outputDataset(final Vertex vertex, final Collection<Dataset> dat
*/
private SplitDataset split(final Collection<Dataset> datasets,
final EventCondition condition, final Vertex vertex) {
final String key = vertex.getId();
SplitDataset conditional = iffs.get(key);
final String vertexId = vertex.getId();
SplitDataset conditional = iffs.get(vertexId);

if (conditional == null) {
final Collection<Dataset> dependencies = flatten(datasets, vertex);
conditional = iffs.get(key);
conditional = iffs.get(vertexId);
// Check that compiling the dependencies did not already instantiate the conditional
// by requiring its else branch.
if (conditional == null) {
final ComputeStepSyntaxElement<SplitDataset> prepared =
DatasetCompiler.splitDataset(dependencies, condition);
LOGGER.debug(
"Compiled conditional\n {} \n into \n {}", vertex, prepared
);
conditional = prepared.instantiate();
iffs.put(key, conditional);
}
final ComputeStepSyntaxElement<SplitDataset> prepared = DatasetCompiler.splitDataset(dependencies, condition);
final Class<? extends Dataset> clazz = getDatasetClass(vertexId, prepared);

LOGGER.debug("Compiled conditional\n {} \n into \n {}", vertex, prepared);

conditional = prepared.instantiate(clazz);
iffs.put(vertexId, conditional);
}
}

return conditional;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.atomic.AtomicLong;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import java.util.stream.StreamSupport;
Expand All @@ -32,10 +33,9 @@ public final class ComputeStepSyntaxElement<T extends Dataset> {
private static final ISimpleCompiler COMPILER = new SimpleCompiler();

/**
* Cache of runtime compiled classes to prevent duplicate classes being compiled.
* Sequential counter to generate the class name
*/
private static final Map<ComputeStepSyntaxElement<?>, Class<? extends Dataset>> CLASS_CACHE
= new HashMap<>();
private static final AtomicLong classSeqCount = new AtomicLong();

/**
* Pattern to remove redundant {@code ;} from formatted code since {@link Formatter} does not
Expand All @@ -49,6 +49,8 @@ public final class ComputeStepSyntaxElement<T extends Dataset> {

private final Class<T> type;

private final long classSeq;

public static <T extends Dataset> ComputeStepSyntaxElement<T> create(
final Iterable<MethodSyntaxElement> methods, final ClassFields fields,
final Class<T> interfce) {
Expand All @@ -60,37 +62,41 @@ private ComputeStepSyntaxElement(final Iterable<MethodSyntaxElement> methods,
this.methods = methods;
this.fields = fields;
type = interfce;
classSeq = classSeqCount.incrementAndGet();
}

@SuppressWarnings("unchecked")
public T instantiate() {
// We need to globally synchronize to avoid concurrency issues with the internal class
// loader and the CLASS_CACHE
public T instantiate(Class<? extends Dataset> clazz) {
try {
return (T) clazz.<T>getConstructor(Map.class).newInstance(ctorArguments());
} catch (final NoSuchMethodException | InvocationTargetException | InstantiationException | IllegalAccessException ex) {
throw new IllegalStateException(ex);
}
}

@SuppressWarnings("unchecked")
public Class<? extends Dataset> compile() {
// We need to globally synchronize to avoid concurrency issues with the internal class loader
// Per https://github.com/elastic/logstash/pull/11482 we should review this lock.
synchronized (COMPILER) {
try {
final Class<? extends Dataset> clazz;
if (CLASS_CACHE.containsKey(this)) {
clazz = CLASS_CACHE.get(this);
final String name = String.format("CompiledDataset%d", classSeq);
final String code = generateCode(name);
if (SOURCE_DIR != null) {
final Path sourceFile = SOURCE_DIR.resolve(String.format("%s.java", name));
Files.write(sourceFile, code.getBytes(StandardCharsets.UTF_8));
COMPILER.cookFile(sourceFile.toFile());
} else {
final String name = String.format("CompiledDataset%d", CLASS_CACHE.size());
final String code = generateCode(name);
if (SOURCE_DIR != null) {
final Path sourceFile = SOURCE_DIR.resolve(String.format("%s.java", name));
Files.write(sourceFile, code.getBytes(StandardCharsets.UTF_8));
COMPILER.cookFile(sourceFile.toFile());
} else {
COMPILER.cook(code);
}
COMPILER.setParentClassLoader(COMPILER.getClassLoader());
clazz = (Class<T>) COMPILER.getClassLoader().loadClass(
String.format("org.logstash.generated.%s", name)
);
CLASS_CACHE.put(this, clazz);
COMPILER.cook(code);
}
return (T) clazz.<T>getConstructor(Map.class).newInstance(ctorArguments());
} catch (final CompileException | ClassNotFoundException | IOException
| NoSuchMethodException | InvocationTargetException | InstantiationException
| IllegalAccessException ex) {
COMPILER.setParentClassLoader(COMPILER.getClassLoader());
clazz = (Class<? extends Dataset>)COMPILER.getClassLoader().loadClass(
String.format("org.logstash.generated.%s", name)
);

return clazz;
} catch (final CompileException | ClassNotFoundException | IOException ex) {
throw new IllegalStateException(ex);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ public static ComputeStepSyntaxElement<Dataset> filterDataset(final Collection<D
final ValueSyntaxElement outputBuffer = fields.add(new ArrayList<>());
final Closure clear = Closure.wrap();
final Closure compute;

if (parents.isEmpty()) {
compute = filterBody(outputBuffer, BATCH_ARG, fields, plugin);
} else {
Expand All @@ -116,29 +117,16 @@ public static ComputeStepSyntaxElement<Dataset> filterDataset(final Collection<D
* @param parents Parent {@link Dataset} to sum and terminate
* @return Dataset representing the sum of given parent {@link Dataset}
*/
public static Dataset terminalDataset(final Collection<Dataset> parents) {
final int count = parents.size();
final Dataset result;
if (count > 1) {
final ClassFields fields = new ClassFields();
final Collection<ValueSyntaxElement> parentFields =
parents.stream().map(fields::add).collect(Collectors.toList());
result = compileOutput(
Closure.wrap(
parentFields.stream().map(DatasetCompiler::computeDataset)
.toArray(MethodLevelSyntaxElement[]::new)
).add(clearSyntax(parentFields)), Closure.EMPTY, fields
).instantiate();
} else if (count == 1) {
// No need for a terminal dataset here, if there is only a single parent node we can
// call it directly.
result = parents.iterator().next();
} else {
throw new IllegalArgumentException(
"Cannot create Terminal Dataset for an empty number of parent datasets"
);
}
return result;
public static ComputeStepSyntaxElement<Dataset> terminalDataset(final Collection<Dataset> parents) {
final ClassFields fields = new ClassFields();
final Collection<ValueSyntaxElement> parentFields =
parents.stream().map(fields::add).collect(Collectors.toList());
return compileOutput(
Closure.wrap(
parentFields.stream().map(DatasetCompiler::computeDataset)
.toArray(MethodLevelSyntaxElement[]::new)
).add(clearSyntax(parentFields)), Closure.EMPTY, fields
);
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,22 +20,24 @@ public final class DatasetCompilerTest {
*/
@Test
public void compilesOutputDataset() {
final ComputeStepSyntaxElement<Dataset> prepared = DatasetCompiler.outputDataset(
Collections.emptyList(),
PipelineTestUtil.buildOutput(events -> {}),
true
);
assertThat(
DatasetCompiler.outputDataset(
Collections.emptyList(),
PipelineTestUtil.buildOutput(events -> {}),
true
).instantiate().compute(RubyUtil.RUBY.newArray(), false, false),
prepared.instantiate(prepared.compile()).compute(RubyUtil.RUBY.newArray(), false, false),
nullValue()
);
}

@Test
public void compilesSplitDataset() {
final FieldReference key = FieldReference.from("foo");
final SplitDataset left = DatasetCompiler.splitDataset(
final ComputeStepSyntaxElement<SplitDataset> prepared = DatasetCompiler.splitDataset(
Collections.emptyList(), event -> event.getEvent().includes(key)
).instantiate();
);
final SplitDataset left = prepared.instantiate(prepared.compile());
final Event trueEvent = new Event();
trueEvent.setField(key, "val");
final JrubyEventExtLibrary.RubyEvent falseEvent =
Expand Down

0 comments on commit 4b29112

Please sign in to comment.