Skip to content

Fixing ExprValueUtils to return a Message using runtime generated extensionRegistry and typeRegistry. #686

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 30, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions testing/src/main/java/dev/cel/testing/testrunner/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,8 @@ java_library(
java_library(
name = "registry_utils",
srcs = ["RegistryUtils.java"],
tags = [
],
deps = [
"//common:cel_descriptors",
"//common/internal:cel_descriptor_pools",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,23 +31,27 @@
import java.util.Set;

/** Utility class for creating registries from a file descriptor set. */
final class RegistryUtils {
public final class RegistryUtils {

private RegistryUtils() {}

/** Returns the {@link FileDescriptorSet} for the given file descriptor set path. */
static FileDescriptorSet getFileDescriptorSet(String fileDescriptorSetPath) throws IOException {
public static FileDescriptorSet getFileDescriptorSet(String fileDescriptorSetPath)
throws IOException {
// We can pass an empty extension registry here because extensions are recovered later when
// creating the extension registry in {@link #createExtensionRegistry}.
return FileDescriptorSet.parseFrom(
Files.toByteArray(new File(fileDescriptorSetPath)), ExtensionRegistry.newInstance());
}

/** Returns the {@link TypeRegistry} for the given file descriptor set. */
static TypeRegistry getTypeRegistry(Set<FileDescriptor> fileDescriptors) throws IOException {
public static TypeRegistry getTypeRegistry(Set<FileDescriptor> fileDescriptors)
throws IOException {
return createTypeRegistry(fileDescriptors);
}

/** Returns the {@link ExtensionRegistry} for the given file descriptor set. */
static ExtensionRegistry getExtensionRegistry(Set<FileDescriptor> fileDescriptors)
public static ExtensionRegistry getExtensionRegistry(Set<FileDescriptor> fileDescriptors)
throws IOException {
return createExtensionRegistry(fileDescriptors);
}
Expand Down
3 changes: 3 additions & 0 deletions testing/src/main/java/dev/cel/testing/utils/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,12 @@ java_library(
],
deps = [
"//common:cel_descriptors",
"//common/internal:cel_descriptor_pools",
"//common/internal:default_instance_message_factory",
"//common/internal:default_message_factory",
"//common/types",
"//common/types:type_providers",
"//testing/testrunner:registry_utils",
"@cel_spec//proto/cel/expr:expr_java_proto",
"@cel_spec//proto/cel/expr/conformance/proto2:test_all_types_java_proto",
"@cel_spec//proto/cel/expr/conformance/proto3:test_all_types_java_proto",
Expand Down
58 changes: 58 additions & 0 deletions testing/src/main/java/dev/cel/testing/utils/ExprValueUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -19,23 +19,29 @@
import dev.cel.expr.Value;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.primitives.UnsignedLong;
import com.google.protobuf.Any;
import com.google.protobuf.ByteString;
import com.google.protobuf.Descriptors.Descriptor;
import com.google.protobuf.Descriptors.FileDescriptor;
import com.google.protobuf.ExtensionRegistry;
import com.google.protobuf.Message;
import com.google.protobuf.NullValue;
import com.google.protobuf.TypeRegistry;
import dev.cel.common.CelDescriptorUtil;
import dev.cel.common.CelDescriptors;
import dev.cel.common.internal.CelDescriptorPool;
import dev.cel.common.internal.DefaultDescriptorPool;
import dev.cel.common.internal.DefaultInstanceMessageFactory;
import dev.cel.common.internal.DefaultMessageFactory;
import dev.cel.common.types.CelType;
import dev.cel.common.types.ListType;
import dev.cel.common.types.MapType;
import dev.cel.common.types.OptionalType;
import dev.cel.common.types.SimpleType;
import dev.cel.common.types.TypeType;
import dev.cel.testing.testrunner.RegistryUtils;
import java.io.IOException;
import java.util.List;
import java.util.Map;
Expand Down Expand Up @@ -79,6 +85,19 @@ public static Object fromValue(Value value) throws IOException {
case OBJECT_VALUE:
{
Any object = value.getObjectValue();

// If the file_descriptor_set_path is set, use the provided file descriptor set created at
// runtime after deserializing the file_descriptor_set file.
// Because of the above reason, DefaultInstanceMessageFactory cannot be used since it
// would always result in a descriptor reference mismatch. Instead, we use
// DefaultMessageFactory to create a DynamicMessage and parse it with `<Any>.getValue()`.
//
// TODO: Remove DynamicMessage parsing once default instance generation is
// fixed.
String fileDescriptorSetPath = System.getProperty("file_descriptor_set_path");
if (fileDescriptorSetPath != null) {
return parseAny(object, fileDescriptorSetPath);
}
Descriptor descriptor =
DEFAULT_TYPE_REGISTRY.getDescriptorForTypeUrl(object.getTypeUrl());
Message prototype =
Expand Down Expand Up @@ -245,6 +264,45 @@ public static Value toValue(Object object, CelType type) throws Exception {
String.format("Unexpected result type: %s", object.getClass()));
}

private static Message parseAny(Any value, String fileDescriptorSetPath) throws IOException {
ImmutableSet<FileDescriptor> fileDescriptors =
CelDescriptorUtil.getFileDescriptorsFromFileDescriptorSet(
RegistryUtils.getFileDescriptorSet(fileDescriptorSetPath));

TypeRegistry typeRegistry = RegistryUtils.getTypeRegistry(fileDescriptors);
ExtensionRegistry extensionRegistry = RegistryUtils.getExtensionRegistry(fileDescriptors);

CelDescriptors allDescriptors =
CelDescriptorUtil.getAllDescriptorsFromFileDescriptor(fileDescriptors);

CelDescriptorPool pool = DefaultDescriptorPool.create(allDescriptors);

DefaultMessageFactory defaultMessageFactory = DefaultMessageFactory.create(pool);
Descriptor descriptor = typeRegistry.getDescriptorForTypeUrl(value.getTypeUrl());

return unpackAny(value, defaultMessageFactory, descriptor, extensionRegistry);
}

private static Message unpackAny(
Any value,
DefaultMessageFactory defaultMessageFactory,
Descriptor descriptor,
ExtensionRegistry extensionRegistry)
throws IOException {
// Generate a default message for the given descriptor.
Message defaultInstance =
defaultMessageFactory
.newBuilder(descriptor.getFullName())
.orElseThrow(
() ->
new NoSuchElementException(
"Could not find a default message for: " + value.getTypeUrl()))
.build();

// Parse the default message using the value from the Any object.
return defaultInstance.getParserForType().parseFrom(value.getValue(), extensionRegistry);
}

private static ExtensionRegistry newDefaultExtensionRegistry() {
ExtensionRegistry extensionRegistry = ExtensionRegistry.newInstance();
dev.cel.expr.conformance.proto2.TestAllTypesExtensions.registerAllExtensions(extensionRegistry);
Expand Down
6 changes: 6 additions & 0 deletions testing/testrunner/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,12 @@ exports_files(
srcs = ["run_testrunner_binary.sh"],
)

java_library(
name = "registry_utils",
visibility = ["//:internal"],
exports = ["//testing/src/main/java/dev/cel/testing/testrunner:registry_utils"],
)

bzl_library(
name = "cel_java_test",
srcs = ["cel_java_test.bzl"],
Expand Down
Loading