Skip to content

Separately generate lite descriptors per message #705

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 28, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import com.google.common.io.Files;
import com.google.protobuf.DescriptorProtos.FileDescriptorProto;
import com.google.protobuf.DescriptorProtos.FileDescriptorSet;
import com.google.protobuf.Descriptors.Descriptor;
import com.google.protobuf.Descriptors.FileDescriptor;
import com.google.protobuf.ExtensionRegistry;
import dev.cel.common.CelDescriptorUtil;
Expand Down Expand Up @@ -104,54 +105,54 @@ public Integer call() throws Exception {
targetDescriptorProtoPath));
}

GeneratedClass generatedClass = codegenCelLiteDescriptor(targetFileDescriptor);
debugPrinter.print("Generated Class:\n" + generatedClass.code());

generatedClassesBuilder.add(generatedClass);
ImmutableList<GeneratedClass> generatedClasses =
codegenCelLiteDescriptors(targetFileDescriptor);
generatedClassesBuilder.addAll(generatedClasses);
}

JavaFileGenerator.writeSrcJar(outPath, generatedClassesBuilder.build());

return 0;
}

private GeneratedClass codegenCelLiteDescriptor(FileDescriptor targetFileDescriptor)
throws Exception {
private ImmutableList<GeneratedClass> codegenCelLiteDescriptors(
FileDescriptor targetFileDescriptor) throws Exception {
String javaPackageName = ProtoJavaQualifiedNames.getJavaPackageName(targetFileDescriptor);
String javaClassName;

// Derive the java class name. Use first encountered message/enum in the FDS as a default,
// with a suffix applied for uniqueness (we don't want to collide with java protoc default
// generated class name).
if (!targetFileDescriptor.getMessageTypes().isEmpty()) {
javaClassName = targetFileDescriptor.getMessageTypes().get(0).getName();
} else if (!targetFileDescriptor.getEnumTypes().isEmpty()) {
javaClassName = targetFileDescriptor.getEnumTypes().get(0).getName();
} else {
throw new IllegalArgumentException("File descriptor does not contain any messages or enums!");
List<Descriptor> descriptors = targetFileDescriptor.getMessageTypes();
if (descriptors.isEmpty()) {
throw new IllegalArgumentException("File descriptor does not contain any messages!");
}

String javaSuffixName =
overriddenDescriptorClassSuffix.isEmpty()
? DEFAULT_CEL_LITE_DESCRIPTOR_CLASS_SUFFIX
: overriddenDescriptorClassSuffix;
javaClassName += javaSuffixName;

ProtoDescriptorCollector descriptorCollector =
ProtoDescriptorCollector.newInstance(debugPrinter);

debugPrinter.print(
String.format(
"Fully qualified descriptor java class name: %s.%s", javaPackageName, javaClassName));

return JavaFileGenerator.generateClass(
JavaFileGeneratorOption.newBuilder()
.setVersion(version)
.setDescriptorClassName(javaClassName)
.setPackageName(javaPackageName)
.setDescriptorMetadataList(
descriptorCollector.collectCodegenMetadata(targetFileDescriptor))
.build());
ImmutableList.Builder<GeneratedClass> generatedClassBuilder = ImmutableList.builder();
for (Descriptor messageDescriptor : descriptors) {
javaClassName = messageDescriptor.getName();
String javaSuffixName =
overriddenDescriptorClassSuffix.isEmpty()
? DEFAULT_CEL_LITE_DESCRIPTOR_CLASS_SUFFIX
: overriddenDescriptorClassSuffix;
javaClassName += javaSuffixName;

debugPrinter.print(
String.format(
"Fully qualified descriptor java class name: %s.%s", javaPackageName, javaClassName));

generatedClassBuilder.add(
JavaFileGenerator.generateClass(
JavaFileGeneratorOption.newBuilder()
.setVersion(version)
.setDescriptorClassName(javaClassName)
.setPackageName(javaPackageName)
.setDescriptorMetadataList(
descriptorCollector.collectCodegenMetadata(messageDescriptor))
.build()));
}

return generatedClassBuilder.build();
}

private String extractProtoPath(String descriptorPath) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,20 +37,19 @@ final class ProtoDescriptorCollector {

private final DebugPrinter debugPrinter;

ImmutableList<LiteDescriptorCodegenMetadata> collectCodegenMetadata(
FileDescriptor targetFileDescriptor) {
ImmutableList<LiteDescriptorCodegenMetadata> collectCodegenMetadata(Descriptor descriptor) {
ImmutableList.Builder<LiteDescriptorCodegenMetadata> descriptorListBuilder =
ImmutableList.builder();
ImmutableList<Descriptor> descriptorList =
collect(targetFileDescriptor).stream()
collectNested(descriptor).stream()
// Don't collect WKTs. They are included in the default descriptor pool.
.filter(d -> !WellKnownProto.getByTypeName(d.getFullName()).isPresent())
.collect(toImmutableList());

for (Descriptor descriptor : descriptorList) {
for (Descriptor messageDescriptor : descriptorList) {
LiteDescriptorCodegenMetadata.Builder descriptorCodegenBuilder =
LiteDescriptorCodegenMetadata.newBuilder();
for (Descriptors.FieldDescriptor fieldDescriptor : descriptor.getFields()) {
for (Descriptors.FieldDescriptor fieldDescriptor : messageDescriptor.getFields()) {
FieldLiteDescriptorMetadata.Builder fieldDescriptorCodegenBuilder =
FieldLiteDescriptorMetadata.newBuilder()
.setFieldNumber(fieldDescriptor.getNumber())
Expand Down Expand Up @@ -85,16 +84,17 @@ ImmutableList<LiteDescriptorCodegenMetadata> collectCodegenMetadata(
debugPrinter.print(
String.format(
"Collecting message %s, for field %s, type: %s",
descriptor.getFullName(),
messageDescriptor.getFullName(),
fieldDescriptor.getFullName(),
fieldDescriptor.getType()));
}

descriptorCodegenBuilder.setProtoTypeName(descriptor.getFullName());
descriptorCodegenBuilder.setProtoTypeName(messageDescriptor.getFullName());
// Maps are resolved as an actual Java map, and doesn't have a MessageLite.Builder associated.
if (!descriptor.getOptions().getMapEntry()) {
if (!messageDescriptor.getOptions().getMapEntry()) {
String sanitizedJavaClassName =
ProtoJavaQualifiedNames.getFullyQualifiedJavaClassName(descriptor).replace('$', '.');
ProtoJavaQualifiedNames.getFullyQualifiedJavaClassName(messageDescriptor)
.replace('$', '.');
descriptorCodegenBuilder.setJavaClassName(sanitizedJavaClassName);
}

Expand Down Expand Up @@ -177,19 +177,17 @@ private static FieldLiteDescriptor.JavaType adaptJavaType(JavaType javaType) {
throw new IllegalArgumentException("Unknown JavaType: " + javaType);
}

private static ImmutableSet<Descriptor> collect(FileDescriptor fileDescriptor) {
private static ImmutableSet<Descriptor> collectNested(Descriptor descriptor) {
ImmutableSet.Builder<Descriptor> builder = ImmutableSet.builder();
for (Descriptor descriptor : fileDescriptor.getMessageTypes()) {
collect(builder, descriptor);
}

collectNested(builder, descriptor);
return builder.build();
}

private static void collect(ImmutableSet.Builder<Descriptor> builder, Descriptor descriptor) {
private static void collectNested(
ImmutableSet.Builder<Descriptor> builder, Descriptor descriptor) {
builder.add(descriptor);
for (Descriptor nested : descriptor.getNestedTypes()) {
collect(builder, nested);
collectNested(builder, nested);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,6 @@ public void getProtoTypeNamesToDescriptors_containsAllMessages() {
assertThat(protoNamesToDescriptors).containsKey("cel.expr.conformance.proto3.TestAllTypes");
assertThat(protoNamesToDescriptors)
.containsKey("cel.expr.conformance.proto3.TestAllTypes.NestedMessage");
assertThat(protoNamesToDescriptors)
.containsKey("cel.expr.conformance.proto3.NestedTestAllTypes");
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import com.google.common.collect.ImmutableList;
import com.google.testing.junit.testparameterinjector.TestParameterInjector;
import dev.cel.expr.conformance.proto3.NestedTestAllTypes;
import dev.cel.expr.conformance.proto3.TestAllTypes;
import dev.cel.testing.testdata.MultiFile;
import org.junit.Test;
Expand All @@ -31,11 +32,14 @@ public void collectCodegenMetadata_containsAllDescriptors() {
ProtoDescriptorCollector collector =
ProtoDescriptorCollector.newInstance(DebugPrinter.newInstance(false));

ImmutableList<LiteDescriptorCodegenMetadata> descriptors =
collector.collectCodegenMetadata(TestAllTypes.getDescriptor().getFile());
ImmutableList<LiteDescriptorCodegenMetadata> testAllTypesDescriptors =
collector.collectCodegenMetadata(TestAllTypes.getDescriptor());
ImmutableList<LiteDescriptorCodegenMetadata> nestedTestAllTypesDescriptors =
collector.collectCodegenMetadata(NestedTestAllTypes.getDescriptor());

// All proto messages, including transitive ones + maps
assertThat(descriptors).hasSize(166);
assertThat(testAllTypesDescriptors).hasSize(165);
assertThat(nestedTestAllTypesDescriptors).hasSize(1);
}

@Test
Expand All @@ -44,7 +48,7 @@ public void collectCodegenMetadata_withProtoDependencies_containsAllDescriptors(
ProtoDescriptorCollector.newInstance(DebugPrinter.newInstance(false));

ImmutableList<LiteDescriptorCodegenMetadata> descriptors =
collector.collectCodegenMetadata(MultiFile.getDescriptor().getFile());
collector.collectCodegenMetadata(MultiFile.getDescriptor());

assertThat(descriptors).hasSize(3);
assertThat(
Expand All @@ -60,7 +64,7 @@ public void collectCodegenMetadata_withProtoDependencies_doesNotContainImportedP
ProtoDescriptorCollector.newInstance(DebugPrinter.newInstance(false));

ImmutableList<LiteDescriptorCodegenMetadata> descriptors =
collector.collectCodegenMetadata(MultiFile.getDescriptor().getFile());
collector.collectCodegenMetadata(MultiFile.getDescriptor());

assertThat(
descriptors.stream()
Expand Down
2 changes: 2 additions & 0 deletions runtime/src/test/java/dev/cel/runtime/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,8 @@ java_library(
"//runtime:type_resolver",
"//runtime:unknown_attributes",
"//runtime:unknown_options",
"//testing/protos:message_with_enum_cel_java_proto",
"//testing/protos:message_with_enum_java_proto",
"//testing/protos:multi_file_cel_java_proto",
"//testing/protos:multi_file_java_proto",
"//testing/protos:single_file_java_proto",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
import dev.cel.common.values.CelValueProvider;
import dev.cel.common.values.ProtoMessageLiteValueProvider;
import dev.cel.expr.conformance.proto3.NestedTestAllTypes;
import dev.cel.expr.conformance.proto3.NestedTestAllTypesCelLiteDescriptor;
import dev.cel.expr.conformance.proto3.TestAllTypes;
import dev.cel.expr.conformance.proto3.TestAllTypesCelLiteDescriptor;
import dev.cel.extensions.CelLiteExtensions;
Expand Down Expand Up @@ -320,7 +321,10 @@ public void eval_protoMessage_primitiveWithDefaults(String checkedExpr) throws E
.setValueProvider(
ProtoMessageLiteValueProvider.newInstance(
dev.cel.expr.conformance.proto2.TestAllTypesCelLiteDescriptor.getDescriptor(),
TestAllTypesCelLiteDescriptor.getDescriptor()))
dev.cel.expr.conformance.proto2.NestedTestAllTypesCelLiteDescriptor
.getDescriptor(),
TestAllTypesCelLiteDescriptor.getDescriptor(),
NestedTestAllTypesCelLiteDescriptor.getDescriptor()))
.build();
// Ensures that all branches of the OR conditions are evaluated, and that appropriate defaults
// are returned for primitives.
Expand Down Expand Up @@ -454,7 +458,10 @@ public void eval_protoMessage_safeTraversal(String checkedExpr) throws Exception
.setValueProvider(
ProtoMessageLiteValueProvider.newInstance(
dev.cel.expr.conformance.proto2.TestAllTypesCelLiteDescriptor.getDescriptor(),
TestAllTypesCelLiteDescriptor.getDescriptor()))
dev.cel.expr.conformance.proto2.NestedTestAllTypesCelLiteDescriptor
.getDescriptor(),
TestAllTypesCelLiteDescriptor.getDescriptor(),
NestedTestAllTypesCelLiteDescriptor.getDescriptor()))
.build();
// Expr: proto2.oneof_type.payload.repeated_string
CelAbstractSyntaxTree ast = readCheckedExpr(checkedExpr);
Expand Down Expand Up @@ -483,7 +490,10 @@ public void eval_protoMessage_deepTraversalReturnsRepeatedStrings(String checked
.setValueProvider(
ProtoMessageLiteValueProvider.newInstance(
dev.cel.expr.conformance.proto2.TestAllTypesCelLiteDescriptor.getDescriptor(),
TestAllTypesCelLiteDescriptor.getDescriptor()))
dev.cel.expr.conformance.proto2.NestedTestAllTypesCelLiteDescriptor
.getDescriptor(),
TestAllTypesCelLiteDescriptor.getDescriptor(),
NestedTestAllTypesCelLiteDescriptor.getDescriptor()))
.build();
// Expr: proto2.oneof_type.payload.repeated_string
CelAbstractSyntaxTree ast = readCheckedExpr(checkedExpr);
Expand Down
31 changes: 31 additions & 0 deletions runtime/src/test/java/dev/cel/runtime/CelLiteRuntimeTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,11 @@
import dev.cel.expr.conformance.proto3.TestAllTypes.NestedMessage;
import dev.cel.expr.conformance.proto3.TestAllTypesCelDescriptor;
import dev.cel.parser.CelStandardMacro;
import dev.cel.testing.testdata.MessageWithEnum;
import dev.cel.testing.testdata.MessageWithEnumCelDescriptor;
import dev.cel.testing.testdata.MultiFile;
import dev.cel.testing.testdata.MultiFileCelDescriptor;
import dev.cel.testing.testdata.SimpleEnum;
import dev.cel.testing.testdata.SingleFileCelDescriptor;
import dev.cel.testing.testdata.SingleFileProto.SingleFile;
import java.util.ArrayList;
Expand Down Expand Up @@ -660,4 +663,32 @@ public void eval_dynFunctionReturnsProto() throws Exception {

assertThat(result).isEqualToDefaultInstance();
}

@Test
public void eval_withEnumField() throws Exception {
CelCompiler celCompiler =
CelCompilerFactory.standardCelCompilerBuilder()
.addVar(
"msg", StructTypeReference.create(MessageWithEnum.getDescriptor().getFullName()))
.addMessageTypes(MessageWithEnum.getDescriptor())
.build();
CelLiteRuntime celLiteRuntime =
CelLiteRuntimeFactory.newLiteRuntimeBuilder()
.setStandardFunctions(CelStandardFunctions.newBuilder().build())
.setValueProvider(
ProtoMessageLiteValueProvider.newInstance(
MessageWithEnumCelDescriptor.getDescriptor()))
.build();
CelAbstractSyntaxTree ast = celCompiler.compile("msg.simple_enum").getAst();

Long result =
(Long)
celLiteRuntime
.createProgram(ast)
.eval(
ImmutableMap.of(
"msg", MessageWithEnum.newBuilder().setSimpleEnum(SimpleEnum.BAR)));

assertThat(result).isEqualTo(SimpleEnum.BAR.getNumber());
}
}
10 changes: 10 additions & 0 deletions testing/protos/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,11 @@ alias(
actual = "//testing/src/test/resources/protos:multi_file_cel_java_proto",
)

alias(
name = "message_with_enum_java_proto",
actual = "//testing/src/test/resources/protos:message_with_enum_java_proto",
)

alias(
name = "multi_file_cel_java_proto_lite",
actual = "//testing/src/test/resources/protos:multi_file_cel_java_proto_lite",
Expand All @@ -43,3 +48,8 @@ alias(
name = "test_all_types_cel_java_proto3",
actual = "//testing/src/test/resources/protos:test_all_types_cel_java_proto3",
)

alias(
name = "message_with_enum_cel_java_proto",
actual = "//testing/src/test/resources/protos:message_with_enum_cel_java_proto",
)
17 changes: 17 additions & 0 deletions testing/src/test/resources/protos/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,16 @@ proto_library(
deps = [":single_file_proto"],
)

proto_library(
name = "message_with_enum_proto",
srcs = ["message_with_enum.proto"],
)

java_proto_library(
name = "message_with_enum_java_proto",
deps = [":message_with_enum_proto"],
)

# Test only. java_proto_library supports generating a jar with multiple proto deps,
# so we must test this case as well for lite descriptors.
# buildifier: disable=LANG_proto_library-single-deps
Expand Down Expand Up @@ -86,3 +96,10 @@ java_lite_proto_cel_library_impl(
java_proto_library_dep = "@cel_spec//proto/cel/expr/conformance/proto3:test_all_types_java_proto",
deps = ["@cel_spec//proto/cel/expr/conformance/proto3:test_all_types_proto"],
)

java_lite_proto_cel_library_impl(
name = "message_with_enum_cel_java_proto",
java_descriptor_class_suffix = "CelDescriptor",
java_proto_library_dep = ":message_with_enum_java_proto",
deps = [":message_with_enum_proto"],
)
31 changes: 31 additions & 0 deletions testing/src/test/resources/protos/message_with_enum.proto
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

syntax = "proto3";

package dev.cel.testing.testdata;

option java_multiple_files = true;
option java_package = "dev.cel.testing.testdata";
option java_outer_classname = "MessageWithEnumProto";

message MessageWithEnum {
SimpleEnum simple_enum = 1;
}

enum SimpleEnum {
FOO = 0;
BAR = 1;
BAZ = 2;
}
Loading