Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions policy/src/main/java/dev/cel/policy/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,7 @@ java_library(
],
deps = [
":compiler",
"//validator:ast_validator",
"@maven//:com_google_errorprone_error_prone_annotations",
],
)
Expand Down
19 changes: 19 additions & 0 deletions policy/src/main/java/dev/cel/policy/CelPolicyCompilerBuilder.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import com.google.errorprone.annotations.CanIgnoreReturnValue;
import com.google.errorprone.annotations.CheckReturnValue;
import dev.cel.validator.CelAstValidator;

/** Interface for building an instance of {@link CelPolicyCompiler} */
public interface CelPolicyCompilerBuilder {
Expand All @@ -38,6 +39,24 @@ public interface CelPolicyCompilerBuilder {
@CanIgnoreReturnValue
CelPolicyCompilerBuilder setAstDepthLimit(int iterationLimit);

/**
* Adds one or more {@link CelAstValidators} to the compiler. These apply per CEL expression in
* the policy.
*/
@CanIgnoreReturnValue
CelPolicyCompilerBuilder addValidators(Iterable<? extends CelAstValidator> validators);

/**
* Adds one or more {@link CelAstValidators} to the compiler. These apply per CEL expression in
* the policy.
*/
@CanIgnoreReturnValue
CelPolicyCompilerBuilder addValidators(CelAstValidator... validators);

/** Removes any custom validators from the compiler builder. */
@CanIgnoreReturnValue
CelPolicyCompilerBuilder clearValidators();

@CheckReturnValue
CelPolicyCompiler build();
}
49 changes: 47 additions & 2 deletions policy/src/main/java/dev/cel/policy/CelPolicyCompilerImpl.java
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ final class CelPolicyCompilerImpl implements CelPolicyCompiler {
private final String variablesPrefix;
private final int iterationLimit;
private final Optional<CelAstValidator> astDepthValidator;
private final Optional<CelValidator> validator;

@Override
public CelCompiledRule compileRule(CelPolicy policy) throws CelPolicyValidationException {
Expand Down Expand Up @@ -194,6 +195,10 @@ private CelCompiledRule compileRuleImpl(
CelType outputType = SimpleType.DYN;
try {
varAst = localCel.compile(expression.value()).getAst();
if (this.validator.isPresent()) {
CelValidationResult result = this.validator.get().validate(varAst);
varAst = result.getAst();
}
outputType = varAst.getResultType();
} catch (CelValidationException e) {
compilerContext.addIssue(expression.id(), e.getErrors());
Expand All @@ -212,6 +217,10 @@ private CelCompiledRule compileRuleImpl(
CelAbstractSyntaxTree conditionAst;
try {
conditionAst = localCel.compile(match.condition().value()).getAst();
if (this.validator.isPresent()) {
CelValidationResult result = this.validator.get().validate(conditionAst);
conditionAst = result.getAst();
}
if (!conditionAst.getResultType().equals(SimpleType.BOOL)) {
compilerContext.addIssue(
match.condition().id(),
Expand All @@ -229,6 +238,10 @@ private CelCompiledRule compileRuleImpl(
ValueString output = match.result().output();
try {
outputAst = localCel.compile(output.value()).getAst();
if (this.validator.isPresent()) {
CelValidationResult result = this.validator.get().validate(outputAst);
outputAst = result.getAst();
}
} catch (CelValidationException e) {
compilerContext.addIssue(output.id(), e.getErrors());
continue;
Expand Down Expand Up @@ -340,10 +353,12 @@ static final class Builder implements CelPolicyCompilerBuilder {
private String variablesPrefix;
private int iterationLimit;
private Optional<CelAstValidator> astDepthLimitValidator;
private ArrayList<CelAstValidator> validators;

private Builder(Cel cel) {
this.cel = cel;
this.astDepthLimitValidator = Optional.of(AstDepthLimitValidator.DEFAULT);
this.validators = new ArrayList<>();
}

@Override
Expand All @@ -360,6 +375,26 @@ public Builder setIterationLimit(int iterationLimit) {
return this;
}

@Override
@CanIgnoreReturnValue
public Builder addValidators(Iterable<? extends CelAstValidator> validators) {
validators.forEach(this.validators::add);
return this;
}

@Override
@CanIgnoreReturnValue
public Builder addValidators(CelAstValidator... validators) {
return addValidators(Arrays.asList(validators));
}

@Override
@CanIgnoreReturnValue
public Builder clearValidators() {
this.validators.clear();
return this;
}

@Override
@CanIgnoreReturnValue
public CelPolicyCompilerBuilder setAstDepthLimit(int astDepthLimit) {
Expand All @@ -374,7 +409,7 @@ public CelPolicyCompilerBuilder setAstDepthLimit(int astDepthLimit) {
@Override
public CelPolicyCompiler build() {
return new CelPolicyCompilerImpl(
cel, this.variablesPrefix, this.iterationLimit, astDepthLimitValidator);
cel, this.variablesPrefix, this.iterationLimit, astDepthLimitValidator, validators);
}
}

Expand All @@ -388,10 +423,20 @@ private CelPolicyCompilerImpl(
Cel cel,
String variablesPrefix,
int iterationLimit,
Optional<CelAstValidator> astDepthValidator) {
Optional<CelAstValidator> astDepthValidator,
List<CelAstValidator> additionalValidators) {
this.cel = checkNotNull(cel);
this.variablesPrefix = checkNotNull(variablesPrefix);
this.iterationLimit = iterationLimit;
this.astDepthValidator = astDepthValidator;
if (additionalValidators.isEmpty()) {
this.validator = Optional.empty();
} else {
this.validator =
Optional.of(
CelValidatorFactory.standardCelValidatorBuilder(cel)
.addAstValidators(additionalValidators)
.build());
}
}
}
5 changes: 4 additions & 1 deletion policy/src/test/java/dev/cel/policy/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,11 @@ java_library(
"//bundle:environment_yaml_parser",
"//common:cel_ast",
"//common:options",
"//common/ast",
"//common/formats:value_string",
"//common/internal",
"//common/navigation",
"//common/navigation:common",
"//common/resources/testdata/proto3:standalone_global_enum_java_proto",
"//common/types",
"//compiler",
Expand All @@ -35,8 +38,8 @@ java_library(
"//policy:validation_exception",
"//runtime",
"//runtime:function_binding",
"//runtime:late_function_binding",
"//testing/protos:single_file_java_proto",
"//validator:ast_validator",
"@cel_spec//proto/cel/expr/conformance/proto3:test_all_types_java_proto",
"@maven//:com_google_guava_guava",
"@maven//:com_google_testparameterinjector_test_parameter_injector",
Expand Down
106 changes: 106 additions & 0 deletions policy/src/test/java/dev/cel/policy/CelPolicyCompilerImplTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,11 @@
import dev.cel.bundle.CelFactory;
import dev.cel.common.CelAbstractSyntaxTree;
import dev.cel.common.CelOptions;
import dev.cel.common.ast.CelConstant;
import dev.cel.common.ast.CelExpr.ExprKind;
import dev.cel.common.navigation.CelNavigableAst;
import dev.cel.common.navigation.CelNavigableExpr;
import dev.cel.common.navigation.TraversalOrder;
import dev.cel.common.types.OptionalType;
import dev.cel.common.types.SimpleType;
import dev.cel.expr.conformance.proto3.TestAllTypes;
Expand All @@ -47,6 +52,8 @@
import dev.cel.runtime.CelLateFunctionBindings;
import dev.cel.testing.testdata.SingleFileProto.SingleFile;
import dev.cel.testing.testdata.proto3.StandaloneGlobalEnum;
import dev.cel.validator.CelAstValidator;
import dev.cel.validator.CelAstValidator.IssuesFactory;
import java.io.IOException;
import java.util.Map;
import java.util.Optional;
Expand Down Expand Up @@ -265,6 +272,105 @@ public void evaluateYamlPolicy_nestedRuleProducesOptionalOutput() throws Excepti
assertThat(evalResult).hasValue(Optional.of(true));
}

static final class NoFooLiteralsValidator implements CelAstValidator {
private static boolean isFooLiteral(CelNavigableExpr node) {
return node.getKind().equals(ExprKind.Kind.CONSTANT)
&& node.expr().constant().getKind().equals(CelConstant.Kind.STRING_VALUE)
&& node.expr().constant().stringValue().equals("foo");
}

@Override
public void validate(CelNavigableAst navigableAst, Cel cel, IssuesFactory issuesFactory) {
navigableAst
.getRoot()
.descendants(TraversalOrder.POST_ORDER)
.filter(NoFooLiteralsValidator::isFooLiteral)
.forEach(node -> issuesFactory.addError(node.id(), "'foo' is a forbidden literal"));
}
}

@Test
public void evaluateYamlPolicy_validatorReportsErrors() throws Exception {
Cel cel = newCel();
String policySource =
"name: nested_rule_with_forbidden_literal\n"
+ "rule:\n"
+ " variables:\n"
+ " - name: 'foo'\n"
+ " expression: \"(true) ? 'bar' : 'foo'\"\n"
+ " match:\n"
+ " - condition: |\n"
+ " variables.foo in ['foo', 'bar', 'foo']\n"
+ " output: >\n"
+ " 'foo' == variables.foo\n";
CelPolicy policy = POLICY_PARSER.parse(policySource);
CelPolicyValidationException e =
assertThrows(
CelPolicyValidationException.class,
() ->
CelPolicyCompilerFactory.newPolicyCompiler(cel)
.addValidators(new NoFooLiteralsValidator())
.build()
.compile(policy));

assertThat(e)
.hasMessageThat()
.contains(
"ERROR: <input>:5:37: 'foo' is a forbidden literal\n"
+ " | expression: \"(true) ? 'bar' : 'foo'\"\n"
+ " | ....................................^");
assertThat(e)
.hasMessageThat()
.contains(
"ERROR: <input>:8:27: 'foo' is a forbidden literal\n"
+ " | variables.foo in ['foo', 'bar', 'foo']\n"
+ " | ..........................^");
assertThat(e)
.hasMessageThat()
.contains(
"ERROR: <input>:8:41: 'foo' is a forbidden literal\n"
+ " | variables.foo in ['foo', 'bar', 'foo']\n"
+ " | ........................................^");
}

// If the condition fails to validate, then the compiler doesn't attempt to compile or validate
// the output, so second test case for coverage.
@Test
public void evaluateYamlPolicy_validatorReportsOutput() throws Exception {
Cel cel = newCel();
String policySource =
"name: nested_rule_with_forbidden_literal\n"
+ "rule:\n"
+ " variables:\n"
+ " - name: 'foo'\n"
+ " expression: \"(true) ? 'bar' : 'foo'\"\n"
+ " match:\n"
+ " - output: >\n"
+ " 'foo' == variables.foo\n";
CelPolicy policy = POLICY_PARSER.parse(policySource);
CelPolicyValidationException e =
assertThrows(
CelPolicyValidationException.class,
() ->
CelPolicyCompilerFactory.newPolicyCompiler(cel)
.addValidators(new NoFooLiteralsValidator())
.build()
.compile(policy));

assertThat(e)
.hasMessageThat()
.contains(
"ERROR: <input>:5:37: 'foo' is a forbidden literal\n"
+ " | expression: \"(true) ? 'bar' : 'foo'\"\n"
+ " | ....................................^");
assertThat(e)
.hasMessageThat()
.contains(
"ERROR: <input>:8:9: 'foo' is a forbidden literal\n"
+ " | 'foo' == variables.foo\n"
+ " | ........^");
}

@Test
public void evaluateYamlPolicy_lateBoundFunction() throws Exception {
String configSource =
Expand Down