Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -52,25 +52,28 @@ public Object createMessage(String messageName, Map<String, Object> values) {
}

@Override
public Object selectField(Object message, String fieldName) {
SelectableValue<CelValue> selectableValue = getSelectableValueOrThrow(message, fieldName);
public Object selectField(String typeName, Object message, String fieldName) {
SelectableValue<CelValue> selectableValue =
getSelectableValueOrThrow(typeName, message, fieldName);

return unwrapCelValue(selectableValue.select(StringValue.create(fieldName)));
}

@Override
public Object hasField(Object message, String fieldName) {
SelectableValue<CelValue> selectableValue = getSelectableValueOrThrow(message, fieldName);
public Object hasField(String messageName, Object message, String fieldName) {
SelectableValue<CelValue> selectableValue =
getSelectableValueOrThrow(messageName, message, fieldName);

return selectableValue.find(StringValue.create(fieldName)).isPresent();
}

@SuppressWarnings("unchecked")
private SelectableValue<CelValue> getSelectableValueOrThrow(Object obj, String fieldName) {
private SelectableValue<CelValue> getSelectableValueOrThrow(
String typeName, Object obj, String fieldName) {
CelValue convertedCelValue;
if ((obj instanceof MessageLite)) {
// TODO: Pass in typeName for lite messages
convertedCelValue = protoCelValueConverter.fromProtoMessageToCelValue("", (MessageLite) obj);
convertedCelValue =
protoCelValueConverter.fromProtoMessageToCelValue(typeName, (MessageLite) obj);
} else {
convertedCelValue = protoCelValueConverter.fromJavaObjectToCelValue(obj);
}
Expand All @@ -83,7 +86,7 @@ private SelectableValue<CelValue> getSelectableValueOrThrow(Object obj, String f
}

@Override
public Object adapt(Object message) {
public Object adapt(String messageName, Object message) {
if (message instanceof CelUnknownSet) {
return message; // CelUnknownSet is handled specially for iterative evaluation. No need to
// adapt to CelValue.
Expand All @@ -94,9 +97,8 @@ public Object adapt(Object message) {
}

if (message instanceof MessageLite) {
// TODO: Pass in typeName for lite messages
return unwrapCelValue(
protoCelValueConverter.fromProtoMessageToCelValue("", (MessageLite) message));
protoCelValueConverter.fromProtoMessageToCelValue(messageName, (MessageLite) message));
} else {
return unwrapCelValue(protoCelValueConverter.fromJavaObjectToCelValue(message));
}
Expand Down
47 changes: 27 additions & 20 deletions runtime/src/main/java/dev/cel/runtime/DefaultInterpreter.java
Original file line number Diff line number Diff line change
Expand Up @@ -295,9 +295,9 @@ private IntermediateResult evalIdent(ExecutionFrame frame, CelExpr expr)
private IntermediateResult resolveIdent(ExecutionFrame frame, CelExpr expr, String name)
throws CelEvaluationException {
// Check whether the type exists in the type check map as a 'type'.
Optional<CelType> checkedType = ast.getType(expr.id());
if (checkedType.isPresent() && checkedType.get().kind() == CelKind.TYPE) {
TypeType typeValue = typeResolver.adaptType(checkedType.get());
CelType checkedType = getCheckedTypeOrThrow(expr);
if (checkedType.kind() == CelKind.TYPE) {
TypeType typeValue = typeResolver.adaptType(checkedType);
return IntermediateResult.create(typeValue);
}

Expand All @@ -309,7 +309,7 @@ private IntermediateResult resolveIdent(ExecutionFrame frame, CelExpr expr, Stri
}

// Value resolved from Binding, it could be Message, PartialMessage or unbound(null)
value = InterpreterUtil.strict(typeProvider.adapt(value));
value = InterpreterUtil.strict(typeProvider.adapt(checkedType.name(), value));
IntermediateResult result = IntermediateResult.create(rawResult.attribute(), value);

if (isLazyExpression) {
Expand Down Expand Up @@ -357,10 +357,12 @@ private IntermediateResult evalFieldSelect(
return IntermediateResult.create(attribute, operand);
}

CelType operandCheckedType = getCheckedTypeOrThrow(operandExpr);
if (isTestOnly) {
return IntermediateResult.create(attribute, typeProvider.hasField(operand, field));
return IntermediateResult.create(
attribute, typeProvider.hasField(operandCheckedType.name(), operand, field));
}
Object fieldValue = typeProvider.selectField(operand, field);
Object fieldValue = typeProvider.selectField(operandCheckedType.name(), operand, field);

return IntermediateResult.create(
attribute, InterpreterUtil.valueOrUnknown(fieldValue, expr.id()));
Expand Down Expand Up @@ -446,7 +448,8 @@ private IntermediateResult evalCall(ExecutionFrame frame, CelExpr expr, CelCall
try {
Object dispatchResult = overload.getDefinition().apply(argArray);
if (celOptions.unwrapWellKnownTypesOnFunctionDispatch()) {
dispatchResult = typeProvider.adapt(dispatchResult);
CelType checkedType = getCheckedTypeOrThrow(expr);
dispatchResult = typeProvider.adapt(checkedType.name(), dispatchResult);
}
return IntermediateResult.create(attr, dispatchResult);
} catch (CelRuntimeException ce) {
Expand Down Expand Up @@ -665,18 +668,7 @@ private IntermediateResult evalType(ExecutionFrame frame, CelCall callExpr)
return argResult;
}

CelType checkedType =
ast.getType(typeExprArg.id())
.orElseThrow(
() ->
CelEvaluationExceptionBuilder.newBuilder(
"expected a runtime type for '%s' from checked expression, but found"
+ " none.",
argResult.getClass().getSimpleName())
.setErrorCode(CelErrorCode.TYPE_NOT_FOUND)
.setMetadata(metadata, typeExprArg.id())
.build());

CelType checkedType = getCheckedTypeOrThrow(typeExprArg);
CelType checkedTypeValue = typeResolver.adaptType(checkedType);
return IntermediateResult.create(
typeResolver.resolveObjectType(argResult.value(), checkedTypeValue));
Expand Down Expand Up @@ -736,7 +728,9 @@ private Optional<IntermediateResult> maybeEvalOptionalSelectField(
}

String field = callExpr.args().get(1).constant().stringValue();
boolean hasField = (boolean) typeProvider.hasField(lhsResult.value(), field);
CelType checkedType = getCheckedTypeOrThrow(expr);
boolean hasField =
(boolean) typeProvider.hasField(checkedType.name(), lhsResult.value(), field);
if (!hasField) {
// Protobuf sets default (zero) values to uninitialized fields.
// In case of CEL's optional values, we want to explicitly return Optional.none()
Expand Down Expand Up @@ -980,6 +974,19 @@ private IntermediateResult evalCelBlock(

return evalInternal(frame, blockCall.args().get(1));
}

private CelType getCheckedTypeOrThrow(CelExpr expr) throws CelEvaluationException {
return ast.getType(expr.id())
.orElseThrow(
() ->
CelEvaluationExceptionBuilder.newBuilder(
"expected a runtime type for expression ID '%d' from checked expression,"
+ " but found none.",
expr.id())
.setErrorCode(CelErrorCode.TYPE_NOT_FOUND)
.setMetadata(metadata, expr.id())
.build());
}
}

/** Contains a CelExpr that is to be lazily evaluated. */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ public DescriptorMessageProvider(ProtoMessageFactory protoMessageFactory, CelOpt

@Override
@SuppressWarnings("unchecked")
public @Nullable Object selectField(Object message, String fieldName) {
public @Nullable Object selectField(String unusedTypeName, Object message, String fieldName) {
boolean isOptionalMessage = false;
if (message instanceof Optional) {
isOptionalMessage = true;
Expand Down Expand Up @@ -139,15 +139,16 @@ public DescriptorMessageProvider(ProtoMessageFactory protoMessageFactory, CelOpt

/** Adapt object to its message value. */
@Override
public Object adapt(Object message) {
public Object adapt(String messageName, Object message) {
if (message instanceof Message) {
return protoAdapter.adaptProtoToValue((Message) message);
}

return message;
}

@Override
public Object hasField(Object message, String fieldName) {
public Object hasField(String messageName, Object message, String fieldName) {
if (message instanceof Optional<?>) {
Optional<?> optionalMessage = (Optional<?>) message;
if (!optionalMessage.isPresent()) {
Expand Down
6 changes: 3 additions & 3 deletions runtime/src/main/java/dev/cel/runtime/LiteRuntimeImpl.java
Original file line number Diff line number Diff line change
Expand Up @@ -148,17 +148,17 @@ public Object createMessage(String messageName, Map<String, Object> values) {
}

@Override
public Object selectField(Object message, String fieldName) {
public Object selectField(String typeName, Object message, String fieldName) {
throw new UnsupportedOperationException("Not implemented yet");
}

@Override
public Object hasField(Object message, String fieldName) {
public Object hasField(String messageName, Object message, String fieldName) {
throw new UnsupportedOperationException("Not implemented yet");
}

@Override
public Object adapt(Object message) {
public Object adapt(String messageName, Object message) {
if (message instanceof MessageLiteOrBuilder) {
throw new UnsupportedOperationException("Not implemented yet");
}
Expand Down
8 changes: 4 additions & 4 deletions runtime/src/main/java/dev/cel/runtime/MessageProvider.java
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,11 @@ public interface MessageProvider {
Object createMessage(String messageName, Map<String, Object> values);

/** Select field from message. */
Object selectField(Object message, String fieldName);
Object selectField(String messageName, Object message, String fieldName);

/** Check whether a field is set on message. */
Object hasField(Object message, String fieldName);
Object hasField(String messageName, Object message, String fieldName);

/** Adapt object to its message value with source location metadata on failure . */
Object adapt(Object message);
/** Adapt object to its message value with source location metadata on failure. */
Object adapt(String messageName, Object message);
}
Original file line number Diff line number Diff line change
Expand Up @@ -58,17 +58,17 @@ public Object createMessage(String messageName, Map<String, Object> values) {
}

@Override
public Object selectField(Object message, String fieldName) {
public Object selectField(String typeName, Object message, String fieldName) {
return null;
}

@Override
public Object hasField(Object message, String fieldName) {
public Object hasField(String messageName, Object message, String fieldName) {
return null;
}

@Override
public Object adapt(Object message) {
public Object adapt(String messageName, Object message) {
return message;
}
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -140,24 +140,40 @@ public void createMessage_badFieldError() {

@Test
public void hasField_mapKeyFound() {
assertThat(provider.hasField(ImmutableMap.of("hello", "world"), "hello")).isEqualTo(true);
assertThat(
provider.hasField(
TestAllTypes.getDescriptor().getFullName(),
ImmutableMap.of("hello", "world"),
"hello"))
.isEqualTo(true);
}

@Test
public void hasField_mapKeyNotFound() {
assertThat(provider.hasField(ImmutableMap.of(), "hello")).isEqualTo(false);
assertThat(
provider.hasField(
TestAllTypes.getDescriptor().getFullName(), ImmutableMap.of(), "hello"))
.isEqualTo(false);
}

@Test
public void selectField_mapKeyFound() {
assertThat(provider.selectField(ImmutableMap.of("hello", "world"), "hello")).isEqualTo("world");
assertThat(
provider.selectField(
TestAllTypes.getDescriptor().getFullName(),
ImmutableMap.of("hello", "world"),
"hello"))
.isEqualTo("world");
}

@Test
public void selectField_mapKeyNotFound() {
CelRuntimeException e =
Assert.assertThrows(
CelRuntimeException.class, () -> provider.selectField(ImmutableMap.of(), "hello"));
CelRuntimeException.class,
() ->
provider.selectField(
TestAllTypes.getDescriptor().getFullName(), ImmutableMap.of(), "hello"));
assertThat(e).hasCauseThat().isInstanceOf(IllegalArgumentException.class);
assertThat(e.getErrorCode()).isEqualTo(CelErrorCode.ATTRIBUTE_NOT_FOUND);
}
Expand All @@ -166,6 +182,7 @@ public void selectField_mapKeyNotFound() {
public void selectField_unsetWrapperField() {
assertThat(
provider.selectField(
TestAllTypes.getDescriptor().getFullName(),
dev.cel.expr.conformance.proto3.TestAllTypes.getDefaultInstance(),
"single_int64_wrapper"))
.isEqualTo(NullValue.NULL_VALUE);
Expand All @@ -175,7 +192,10 @@ public void selectField_unsetWrapperField() {
public void selectField_nonProtoObjectError() {
CelRuntimeException e =
Assert.assertThrows(
CelRuntimeException.class, () -> provider.selectField("hello", "not_a_field"));
CelRuntimeException.class,
() ->
provider.selectField(
TestAllTypes.getDescriptor().getFullName(), "hello", "not_a_field"));
assertThat(e).hasCauseThat().isInstanceOf(IllegalArgumentException.class);
assertThat(e.getErrorCode()).isEqualTo(CelErrorCode.ATTRIBUTE_NOT_FOUND);
}
Expand All @@ -194,6 +214,7 @@ public void selectField_extensionUsingDynamicTypes() {
long result =
(long)
provider.selectField(
TestAllTypes.getDescriptor().getFullName(),
TestAllTypes.newBuilder().setExtension(TestAllTypesExtensions.int32Ext, 10).build(),
TestAllTypesProto.getDescriptor().getPackage() + ".int32_ext");

Expand Down
Loading