Skip to content

Commit 70ddd44

Browse files
l46kokcopybara-github
authored andcommitted
Add ProtoMessageLiteValueProvider
PiperOrigin-RevId: 749868732
1 parent 220312c commit 70ddd44

File tree

8 files changed

+693
-2
lines changed

8 files changed

+693
-2
lines changed

common/src/main/java/dev/cel/common/values/BUILD.bazel

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,26 @@ java_library(
179179
"//protobuf:cel_lite_descriptor",
180180
"@maven//:com_google_errorprone_error_prone_annotations",
181181
"@maven//:com_google_guava_guava",
182+
"@maven//:com_google_protobuf_protobuf_java",
182183
"@maven//:org_jspecify_jspecify",
183184
"@maven_android//:com_google_protobuf_protobuf_javalite",
184185
],
185186
)
187+
188+
java_library(
189+
name = "proto_message_lite_value_provider",
190+
srcs = ["ProtoMessageLiteValueProvider.java"],
191+
tags = [
192+
],
193+
deps = [
194+
":cel_value",
195+
":cel_value_provider",
196+
":proto_message_lite_value",
197+
"//common/internal:cel_lite_descriptor_pool",
198+
"//common/internal:default_lite_descriptor_pool",
199+
"//protobuf:cel_lite_descriptor",
200+
"@maven//:com_google_errorprone_error_prone_annotations",
201+
"@maven//:com_google_guava_guava",
202+
"@maven_android//:com_google_protobuf_protobuf_javalite",
203+
],
204+
)

common/src/main/java/dev/cel/common/values/ProtoLiteCelValueConverter.java

Lines changed: 283 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,31 @@
1616

1717
import static com.google.common.base.Preconditions.checkNotNull;
1818

19+
import com.google.common.annotations.VisibleForTesting;
20+
import com.google.common.base.Defaults;
21+
import com.google.common.collect.ImmutableList;
22+
import com.google.common.collect.ImmutableMap;
23+
import com.google.common.primitives.UnsignedLong;
1924
import com.google.errorprone.annotations.Immutable;
25+
import com.google.protobuf.ByteString;
26+
import com.google.protobuf.CodedInputStream;
27+
import com.google.protobuf.ExtensionRegistryLite;
2028
import com.google.protobuf.MessageLite;
29+
import com.google.protobuf.WireFormat;
2130
import dev.cel.common.annotations.Internal;
2231
import dev.cel.common.internal.CelLiteDescriptorPool;
2332
import dev.cel.common.internal.WellKnownProto;
33+
import dev.cel.protobuf.CelLiteDescriptor.FieldLiteDescriptor;
34+
import dev.cel.protobuf.CelLiteDescriptor.FieldLiteDescriptor.CelFieldValueType;
35+
import dev.cel.protobuf.CelLiteDescriptor.FieldLiteDescriptor.JavaType;
2436
import dev.cel.protobuf.CelLiteDescriptor.MessageLiteDescriptor;
37+
import java.io.IOException;
38+
import java.util.AbstractMap;
39+
import java.util.ArrayList;
40+
import java.util.Collection;
41+
import java.util.LinkedHashMap;
42+
import java.util.List;
43+
import java.util.Map;
2544

2645
/**
2746
* {@code ProtoLiteCelValueConverter} handles bidirectional conversion between native Java and
@@ -43,6 +62,270 @@ public static ProtoLiteCelValueConverter newInstance(
4362
return new ProtoLiteCelValueConverter(celLiteDescriptorPool);
4463
}
4564

65+
private static Object readPrimitiveField(
66+
CodedInputStream inputStream, FieldLiteDescriptor fieldDescriptor) throws IOException {
67+
switch (fieldDescriptor.getProtoFieldType()) {
68+
case SINT32:
69+
return inputStream.readSInt32();
70+
case SINT64:
71+
return inputStream.readSInt64();
72+
case INT32:
73+
case ENUM:
74+
return inputStream.readInt32();
75+
case INT64:
76+
return inputStream.readInt64();
77+
case UINT32:
78+
return UnsignedLong.fromLongBits(inputStream.readUInt32());
79+
case UINT64:
80+
return UnsignedLong.fromLongBits(inputStream.readUInt64());
81+
case BOOL:
82+
return inputStream.readBool();
83+
case FLOAT:
84+
case FIXED32:
85+
case SFIXED32:
86+
return readFixed32BitField(inputStream, fieldDescriptor);
87+
case DOUBLE:
88+
case FIXED64:
89+
case SFIXED64:
90+
return readFixed64BitField(inputStream, fieldDescriptor);
91+
default:
92+
throw new IllegalStateException(
93+
"Unexpected field type: " + fieldDescriptor.getProtoFieldType());
94+
}
95+
}
96+
97+
private static Object readFixed32BitField(
98+
CodedInputStream inputStream, FieldLiteDescriptor fieldDescriptor) throws IOException {
99+
switch (fieldDescriptor.getProtoFieldType()) {
100+
case FLOAT:
101+
return inputStream.readFloat();
102+
case FIXED32:
103+
case SFIXED32:
104+
return inputStream.readRawLittleEndian32();
105+
default:
106+
throw new IllegalStateException(
107+
"Unexpected field type: " + fieldDescriptor.getProtoFieldType());
108+
}
109+
}
110+
111+
private static Object readFixed64BitField(
112+
CodedInputStream inputStream, FieldLiteDescriptor fieldDescriptor) throws IOException {
113+
switch (fieldDescriptor.getProtoFieldType()) {
114+
case DOUBLE:
115+
return inputStream.readDouble();
116+
case FIXED64:
117+
case SFIXED64:
118+
return inputStream.readRawLittleEndian64();
119+
default:
120+
throw new IllegalStateException(
121+
"Unexpected field type: " + fieldDescriptor.getProtoFieldType());
122+
}
123+
}
124+
125+
private Object readLengthDelimitedField(
126+
CodedInputStream inputStream, FieldLiteDescriptor fieldDescriptor) throws IOException {
127+
FieldLiteDescriptor.Type fieldType = fieldDescriptor.getProtoFieldType();
128+
129+
switch (fieldType) {
130+
case BYTES:
131+
return inputStream.readBytes();
132+
case MESSAGE:
133+
MessageLite.Builder builder =
134+
getDefaultMessageBuilder(fieldDescriptor.getFieldProtoTypeName());
135+
136+
inputStream.readMessage(builder, ExtensionRegistryLite.getEmptyRegistry());
137+
return builder.build();
138+
case STRING:
139+
return inputStream.readStringRequireUtf8();
140+
default:
141+
throw new IllegalStateException("Unexpected field type: " + fieldType);
142+
}
143+
}
144+
145+
private MessageLite.Builder getDefaultMessageBuilder(String protoTypeName) {
146+
return descriptorPool.getDescriptorOrThrow(protoTypeName).newMessageBuilder();
147+
}
148+
149+
CelValue getDefaultCelValue(String protoTypeName, String fieldName) {
150+
MessageLiteDescriptor messageDescriptor = descriptorPool.getDescriptorOrThrow(protoTypeName);
151+
FieldLiteDescriptor fieldDescriptor = messageDescriptor.getByFieldNameOrThrow(fieldName);
152+
153+
Object defaultValue = getDefaultValue(fieldDescriptor);
154+
if (defaultValue instanceof MessageLite) {
155+
return fromProtoMessageToCelValue(
156+
fieldDescriptor.getFieldProtoTypeName(), (MessageLite) defaultValue);
157+
} else {
158+
return fromJavaObjectToCelValue(getDefaultValue(fieldDescriptor));
159+
}
160+
}
161+
162+
private Object getDefaultValue(FieldLiteDescriptor fieldDescriptor) {
163+
FieldLiteDescriptor.CelFieldValueType celFieldValueType =
164+
fieldDescriptor.getCelFieldValueType();
165+
switch (celFieldValueType) {
166+
case LIST:
167+
return ImmutableList.of();
168+
case MAP:
169+
return ImmutableMap.of();
170+
case SCALAR:
171+
return getScalarDefaultValue(fieldDescriptor);
172+
}
173+
throw new IllegalStateException("Unexpected cel field value type: " + celFieldValueType);
174+
}
175+
176+
private Object getScalarDefaultValue(FieldLiteDescriptor fieldDescriptor) {
177+
JavaType type = fieldDescriptor.getJavaType();
178+
switch (type) {
179+
case INT:
180+
return fieldDescriptor.getProtoFieldType().equals(FieldLiteDescriptor.Type.UINT32)
181+
? UnsignedLong.ZERO
182+
: Defaults.defaultValue(long.class);
183+
case LONG:
184+
return fieldDescriptor.getProtoFieldType().equals(FieldLiteDescriptor.Type.UINT64)
185+
? UnsignedLong.ZERO
186+
: Defaults.defaultValue(long.class);
187+
case ENUM:
188+
return Defaults.defaultValue(long.class);
189+
case FLOAT:
190+
return Defaults.defaultValue(float.class);
191+
case DOUBLE:
192+
return Defaults.defaultValue(double.class);
193+
case BOOLEAN:
194+
return Defaults.defaultValue(boolean.class);
195+
case STRING:
196+
return "";
197+
case BYTE_STRING:
198+
return ByteString.EMPTY;
199+
case MESSAGE:
200+
if (WellKnownProto.isWrapperType(fieldDescriptor.getFieldProtoTypeName())) {
201+
return NullValue.NULL_VALUE;
202+
}
203+
204+
return getDefaultMessageBuilder(fieldDescriptor.getFieldProtoTypeName()).build();
205+
}
206+
throw new IllegalStateException("Unexpected java type: " + type);
207+
}
208+
209+
private ImmutableList<Object> readPackedRepeatedFields(
210+
CodedInputStream inputStream, FieldLiteDescriptor fieldDescriptor) throws IOException {
211+
int length = inputStream.readInt32();
212+
int oldLimit = inputStream.pushLimit(length);
213+
ImmutableList.Builder<Object> builder = ImmutableList.builder();
214+
while (inputStream.getBytesUntilLimit() > 0) {
215+
builder.add(readPrimitiveField(inputStream, fieldDescriptor));
216+
}
217+
inputStream.popLimit(oldLimit);
218+
return builder.build();
219+
}
220+
221+
private Map.Entry<Object, Object> readSingleMapEntry(
222+
CodedInputStream inputStream, FieldLiteDescriptor fieldDescriptor) throws IOException {
223+
ImmutableMap<String, Object> singleMapEntry =
224+
readAllFields(inputStream.readByteArray(), fieldDescriptor.getFieldProtoTypeName());
225+
Object key = checkNotNull(singleMapEntry.get("key"));
226+
Object value = checkNotNull(singleMapEntry.get("value"));
227+
228+
return new AbstractMap.SimpleEntry<>(key, value);
229+
}
230+
231+
@VisibleForTesting
232+
ImmutableMap<String, Object> readAllFields(byte[] bytes, String protoTypeName)
233+
throws IOException {
234+
// TODO: Handle unknown fields by collecting them into a separate map.
235+
MessageLiteDescriptor messageDescriptor = descriptorPool.getDescriptorOrThrow(protoTypeName);
236+
CodedInputStream inputStream = CodedInputStream.newInstance(bytes);
237+
238+
ImmutableMap.Builder<String, Object> fieldValues = ImmutableMap.builder();
239+
Map<Integer, List<Object>> repeatedFieldValues = new LinkedHashMap<>();
240+
Map<Integer, Map<Object, Object>> mapFieldValues = new LinkedHashMap<>();
241+
for (int tag = inputStream.readTag(); tag != 0; tag = inputStream.readTag()) {
242+
int tagWireType = WireFormat.getTagWireType(tag);
243+
int fieldNumber = WireFormat.getTagFieldNumber(tag);
244+
FieldLiteDescriptor fieldDescriptor = messageDescriptor.getByFieldNumberOrThrow(fieldNumber);
245+
246+
Object payload;
247+
switch (tagWireType) {
248+
case WireFormat.WIRETYPE_VARINT:
249+
payload = readPrimitiveField(inputStream, fieldDescriptor);
250+
break;
251+
case WireFormat.WIRETYPE_FIXED32:
252+
payload = readFixed32BitField(inputStream, fieldDescriptor);
253+
break;
254+
case WireFormat.WIRETYPE_FIXED64:
255+
payload = readFixed64BitField(inputStream, fieldDescriptor);
256+
break;
257+
case WireFormat.WIRETYPE_LENGTH_DELIMITED:
258+
CelFieldValueType celFieldValueType = fieldDescriptor.getCelFieldValueType();
259+
switch (celFieldValueType) {
260+
case LIST:
261+
if (fieldDescriptor.getIsPacked()) {
262+
payload = readPackedRepeatedFields(inputStream, fieldDescriptor);
263+
} else {
264+
FieldLiteDescriptor.Type protoFieldType = fieldDescriptor.getProtoFieldType();
265+
boolean isLenDelimited =
266+
protoFieldType.equals(FieldLiteDescriptor.Type.MESSAGE)
267+
|| protoFieldType.equals(FieldLiteDescriptor.Type.STRING)
268+
|| protoFieldType.equals(FieldLiteDescriptor.Type.BYTES);
269+
if (!isLenDelimited) {
270+
throw new IllegalStateException(
271+
"Unexpected field type encountered for LEN-Delimited record: "
272+
+ protoFieldType);
273+
}
274+
275+
payload = readLengthDelimitedField(inputStream, fieldDescriptor);
276+
}
277+
break;
278+
case MAP:
279+
Map<Object, Object> fieldMap =
280+
mapFieldValues.computeIfAbsent(fieldNumber, (unused) -> new LinkedHashMap<>());
281+
Map.Entry<Object, Object> mapEntry = readSingleMapEntry(inputStream, fieldDescriptor);
282+
fieldMap.put(mapEntry.getKey(), mapEntry.getValue());
283+
payload = fieldMap;
284+
break;
285+
default:
286+
payload = readLengthDelimitedField(inputStream, fieldDescriptor);
287+
break;
288+
}
289+
break;
290+
case WireFormat.WIRETYPE_START_GROUP:
291+
case WireFormat.WIRETYPE_END_GROUP:
292+
// TODO: Support groups
293+
throw new UnsupportedOperationException("Groups are not supported");
294+
default:
295+
throw new IllegalArgumentException("Unexpected wire type: " + tagWireType);
296+
}
297+
298+
if (fieldDescriptor.getCelFieldValueType().equals(CelFieldValueType.LIST)) {
299+
String fieldName = fieldDescriptor.getFieldName();
300+
List<Object> repeatedValues =
301+
repeatedFieldValues.computeIfAbsent(
302+
fieldNumber,
303+
(unused) -> {
304+
List<Object> newList = new ArrayList<>();
305+
fieldValues.put(fieldName, newList);
306+
return newList;
307+
});
308+
309+
if (payload instanceof Collection) {
310+
repeatedValues.addAll((Collection<?>) payload);
311+
} else {
312+
repeatedValues.add(payload);
313+
}
314+
} else {
315+
fieldValues.put(fieldDescriptor.getFieldName(), payload);
316+
}
317+
}
318+
319+
// Protobuf encoding follows a "last one wins" semantics. This means for duplicated fields,
320+
// we accept the last value encountered.
321+
return fieldValues.buildKeepingLast();
322+
}
323+
324+
ImmutableMap<String, Object> readAllFields(MessageLite msg, String protoTypeName)
325+
throws IOException {
326+
return readAllFields(msg.toByteArray(), protoTypeName);
327+
}
328+
46329
@Override
47330
public CelValue fromProtoMessageToCelValue(String protoTypeName, MessageLite msg) {
48331
checkNotNull(msg);

common/src/main/java/dev/cel/common/values/ProtoMessageLiteValue.java

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,14 @@
1515
package dev.cel.common.values;
1616

1717
import com.google.auto.value.AutoValue;
18+
import com.google.auto.value.extension.memoized.Memoized;
1819
import com.google.common.base.Preconditions;
20+
import com.google.common.collect.ImmutableMap;
1921
import com.google.errorprone.annotations.Immutable;
2022
import com.google.protobuf.MessageLite;
2123
import dev.cel.common.types.CelType;
2224
import dev.cel.common.types.StructTypeReference;
25+
import java.io.IOException;
2326
import java.util.Optional;
2427

2528
/**
@@ -42,19 +45,32 @@ public abstract class ProtoMessageLiteValue extends StructValue<StringValue> {
4245

4346
abstract ProtoLiteCelValueConverter protoLiteCelValueConverter();
4447

48+
@Memoized
49+
ImmutableMap<String, Object> fieldValues() {
50+
try {
51+
return protoLiteCelValueConverter().readAllFields(value(), celType().name());
52+
} catch (IOException e) {
53+
throw new IllegalStateException("Unable to read message fields for " + celType().name(), e);
54+
}
55+
}
56+
4557
@Override
4658
public boolean isZeroValue() {
4759
return value().getDefaultInstanceForType().equals(value());
4860
}
4961

5062
@Override
5163
public CelValue select(StringValue field) {
52-
throw new UnsupportedOperationException("Not implemented yet");
64+
return find(field)
65+
.orElseGet(
66+
() -> protoLiteCelValueConverter().getDefaultCelValue(celType().name(), field.value()));
5367
}
5468

5569
@Override
5670
public Optional<CelValue> find(StringValue field) {
57-
throw new UnsupportedOperationException("Not implemented yet");
71+
Object fieldValue = fieldValues().get(field.value());
72+
return Optional.ofNullable(fieldValue)
73+
.map(value -> protoLiteCelValueConverter().fromJavaObjectToCelValue(fieldValue));
5874
}
5975

6076
public static ProtoMessageLiteValue create(

0 commit comments

Comments
 (0)