Skip to content

Commit

Permalink
Only use ExprEval in ConstantExpr if its known that it will be safe (#…
Browse files Browse the repository at this point in the history
…15694)

* `Expr#singleThreaded` which creates a singleThreaded version of the actual expression (caching ExprEval is allowed)
* `Expr#makeSingleThreaded` to make a whole subtree of expressions 'singleThreaded' - uses `Shuttle` to create the new expression tree
* `ConstantExpr#singleThreaded` creates a specialized `ConstantExpr` which does cache the `ExprEval`
* some `@Immutable` annotations were added to make it more likely to notice that there might be something off if a similar change will be made around here for some reason
  • Loading branch information
kgyrtkirk authored Mar 5, 2024
1 parent e13ed7b commit 27d7c30
Show file tree
Hide file tree
Showing 13 changed files with 221 additions and 36 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import com.google.common.collect.ImmutableList;
import org.apache.druid.common.config.NullHandling;
import org.apache.druid.java.util.common.Intervals;
import org.apache.druid.java.util.common.StringUtils;
import org.apache.druid.java.util.common.granularity.Granularities;
import org.apache.druid.java.util.common.guava.Sequence;
import org.apache.druid.java.util.common.io.Closer;
Expand Down Expand Up @@ -517,6 +518,51 @@ public void caseSearched2(Blackhole blackhole)
blackhole.consume(results);
}


@Benchmark
public void caseSearched100(Blackhole blackhole)
{

StringBuilder caseBranches = new StringBuilder();
for (int i = 0; i < 100; i++) {
caseBranches.append(
StringUtils.format(
"n == %d, %d,",
i,
i * i
)
);
}

final Sequence<Cursor> cursors = new QueryableIndexStorageAdapter(index).makeCursors(
null,
index.getDataInterval(),
VirtualColumns.create(
ImmutableList.of(
new ExpressionVirtualColumn(
"v",
"case_searched(s == 'asd' || isnull(s) || n == 1, 1, " + caseBranches + " 3)",
ColumnType.LONG,
TestExprMacroTable.INSTANCE
)
)
),
Granularities.ALL,
false,
null
);

final List<?> results = cursors
.map(cursor -> {
final ColumnValueSelector selector = cursor.getColumnSelectorFactory().makeColumnValueSelector("v");
consumeLong(cursor, selector, blackhole);
return null;
})
.toList();

blackhole.consume(results);
}

@Benchmark
public void caseSearchedWithLookup(Blackhole blackhole)
{
Expand Down
117 changes: 86 additions & 31 deletions processing/src/main/java/org/apache/druid/math/expr/ConstantExpr.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
package org.apache.druid.math.expr;

import com.google.common.base.Preconditions;
import com.google.errorprone.annotations.Immutable;
import org.apache.commons.lang.StringEscapeUtils;
import org.apache.druid.common.config.NullHandling;
import org.apache.druid.java.util.common.IAE;
Expand All @@ -29,6 +30,8 @@
import org.apache.druid.segment.column.TypeStrategy;

import javax.annotation.Nullable;
import javax.annotation.concurrent.NotThreadSafe;

import java.math.BigInteger;
import java.nio.ByteBuffer;
import java.util.Arrays;
Expand All @@ -39,9 +42,12 @@
* {@link Expr.ObjectBinding}. {@link ConstantExpr} are terminal nodes of an expression tree, and have no children
* {@link Expr}.
*/
abstract class ConstantExpr<T> implements Expr
@Immutable
abstract class ConstantExpr<T> implements Expr, Expr.SingleThreadSpecializable
{
final ExpressionType outputType;

@SuppressWarnings("Immutable")
@Nullable
final T value;

Expand All @@ -53,38 +59,38 @@ protected ConstantExpr(ExpressionType outputType, @Nullable T value)

@Nullable
@Override
public ExpressionType getOutputType(InputBindingInspector inspector)
public final ExpressionType getOutputType(InputBindingInspector inspector)
{
// null isn't really a type, so don't claim anything
return value == null ? null : outputType;
}

@Override
public boolean isLiteral()
public final boolean isLiteral()
{
return true;
}

@Override
public boolean isNullLiteral()
public final boolean isNullLiteral()
{
return value == null;
}

@Override
public Object getLiteralValue()
public final Object getLiteralValue()
{
return value;
}

@Override
public Expr visit(Shuttle shuttle)
public final Expr visit(Shuttle shuttle)
{
return shuttle.visit(this);
}

@Override
public BindingAnalysis analyzeInputs()
public final BindingAnalysis analyzeInputs()
{
return BindingAnalysis.EMTPY;
}
Expand All @@ -100,6 +106,67 @@ public String stringify()
{
return toString();
}

@Override
public final ExprEval eval(ObjectBinding bindings)
{
return realEval();
}

protected abstract ExprEval<T> realEval();


@Override
public Expr toSingleThreaded()
{
return new ExprEvalBasedConstantExpr<T>(realEval());
}

/**
* Constant expression based on a concreate ExprEval.
*
* Not multi-thread safe.
*/
@NotThreadSafe
@SuppressWarnings("Immutable")
private static final class ExprEvalBasedConstantExpr<T> extends ConstantExpr<T>
{
private final ExprEval<T> eval;

private ExprEvalBasedConstantExpr(ExprEval<T> eval)
{
super(eval.type(), eval.value);
this.eval = eval;
}

@Override
protected ExprEval<T> realEval()
{
return eval;
}

@Override
public int hashCode()
{
return Objects.hash(eval);
}

@Override
public boolean equals(Object obj)
{
if (this == obj) {
return true;
}
if (obj == null) {
return false;
}
if (getClass() != obj.getClass()) {
return false;
}
ExprEvalBasedConstantExpr<?> other = (ExprEvalBasedConstantExpr<?>) obj;
return Objects.equals(eval, other.eval);
}
}
}

/**
Expand All @@ -121,7 +188,7 @@ public String toString()
}

@Override
public ExprEval eval(ObjectBinding bindings)
protected ExprEval realEval()
{
// Eval succeeds if the BigInteger is in long range.
// Callers that need to process out-of-long-range values, like UnaryMinusExpr, must use getLiteralValue().
Expand Down Expand Up @@ -157,12 +224,9 @@ public int hashCode()

class LongExpr extends ConstantExpr<Long>
{
private final ExprEval expr;

LongExpr(Long value)
{
super(ExpressionType.LONG, Preconditions.checkNotNull(value, "value"));
expr = ExprEval.ofLong(value);
}

@Override
Expand All @@ -172,9 +236,9 @@ public String toString()
}

@Override
public ExprEval eval(ObjectBinding bindings)
protected ExprEval realEval()
{
return expr;
return ExprEval.ofLong(value);
}

@Override
Expand Down Expand Up @@ -211,7 +275,7 @@ class NullLongExpr extends ConstantExpr<Long>
}

@Override
public ExprEval eval(ObjectBinding bindings)
protected ExprEval realEval()
{
return ExprEval.ofLong(null);
}
Expand Down Expand Up @@ -243,12 +307,9 @@ public String toString()

class DoubleExpr extends ConstantExpr<Double>
{
private final ExprEval expr;

DoubleExpr(Double value)
{
super(ExpressionType.DOUBLE, Preconditions.checkNotNull(value, "value"));
expr = ExprEval.ofDouble(value);
}

@Override
Expand All @@ -258,9 +319,9 @@ public String toString()
}

@Override
public ExprEval eval(ObjectBinding bindings)
protected ExprEval realEval()
{
return expr;
return ExprEval.ofDouble(value);
}

@Override
Expand Down Expand Up @@ -297,7 +358,7 @@ class NullDoubleExpr extends ConstantExpr<Double>
}

@Override
public ExprEval eval(ObjectBinding bindings)
protected ExprEval realEval()
{
return ExprEval.ofDouble(null);
}
Expand Down Expand Up @@ -329,12 +390,9 @@ public String toString()

class StringExpr extends ConstantExpr<String>
{
private final ExprEval expr;

StringExpr(@Nullable String value)
{
super(ExpressionType.STRING, NullHandling.emptyToNullIfNeeded(value));
expr = ExprEval.of(value);
}

@Override
Expand All @@ -344,9 +402,9 @@ public String toString()
}

@Override
public ExprEval eval(ObjectBinding bindings)
protected ExprEval realEval()
{
return expr;
return ExprEval.of(value);
}

@Override
Expand Down Expand Up @@ -391,7 +449,7 @@ public ArrayExpr(ExpressionType outputType, @Nullable Object[] value)
}

@Override
public ExprEval eval(ObjectBinding bindings)
protected ExprEval realEval()
{
return ExprEval.ofArray(outputType, value);
}
Expand Down Expand Up @@ -473,18 +531,15 @@ public String toString()

class ComplexExpr extends ConstantExpr<Object>
{
private final ExprEval expr;

protected ComplexExpr(ExpressionType outputType, @Nullable Object value)
{
super(outputType, value);
expr = ExprEval.ofComplex(outputType, value);
}

@Override
public ExprEval eval(ObjectBinding bindings)
protected ExprEval realEval()
{
return expr;
return ExprEval.ofComplex(outputType, value);
}

@Override
Expand Down
32 changes: 32 additions & 0 deletions processing/src/main/java/org/apache/druid/math/expr/Expr.java
Original file line number Diff line number Diff line change
Expand Up @@ -763,4 +763,36 @@ private static Set<String> map(
return results;
}
}


/**
* Returns the single-threaded version of the given expression tree.
*
* Nested expressions in the subtree are also optimized.
* Individual {@link Expr}-s which have a singleThreaded implementation via {@link SingleThreadSpecializable} are substituted.
*/
static Expr singleThreaded(Expr expr)
{
return expr.visit(
node -> {
if (node instanceof SingleThreadSpecializable) {
SingleThreadSpecializable canBeSingleThreaded = (SingleThreadSpecializable) node;
return canBeSingleThreaded.toSingleThreaded();
} else {
return node;
}
}
);
}

/**
* Implementing this interface allows to provide a non-threadsafe {@link Expr} implementation.
*/
interface SingleThreadSpecializable
{
/**
* Non-threadsafe version of this expression.
*/
Expr toSingleThreaded();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,13 @@

package org.apache.druid.math.expr;

import com.google.errorprone.annotations.Immutable;
import org.apache.druid.segment.column.TypeDescriptor;

/**
* Base 'value' types of Druid expression language, all {@link Expr} must evaluate to one of these types.
*/
@Immutable
public enum ExprType implements TypeDescriptor
{
DOUBLE,
Expand Down
Loading

0 comments on commit 27d7c30

Please sign in to comment.