16
16
17
17
import static com .google .common .base .Preconditions .checkNotNull ;
18
18
19
+ import com .google .common .annotations .VisibleForTesting ;
20
+ import com .google .common .base .Defaults ;
21
+ import com .google .common .collect .ImmutableList ;
22
+ import com .google .common .collect .ImmutableMap ;
23
+ import com .google .common .primitives .UnsignedLong ;
19
24
import com .google .errorprone .annotations .Immutable ;
25
+ import com .google .protobuf .ByteString ;
26
+ import com .google .protobuf .CodedInputStream ;
27
+ import com .google .protobuf .ExtensionRegistryLite ;
20
28
import com .google .protobuf .MessageLite ;
29
+ import com .google .protobuf .WireFormat ;
21
30
import dev .cel .common .annotations .Internal ;
22
31
import dev .cel .common .internal .CelLiteDescriptorPool ;
23
32
import dev .cel .common .internal .WellKnownProto ;
33
+ import dev .cel .protobuf .CelLiteDescriptor .FieldLiteDescriptor ;
34
+ import dev .cel .protobuf .CelLiteDescriptor .FieldLiteDescriptor .CelFieldValueType ;
35
+ import dev .cel .protobuf .CelLiteDescriptor .FieldLiteDescriptor .JavaType ;
24
36
import dev .cel .protobuf .CelLiteDescriptor .MessageLiteDescriptor ;
37
+ import java .io .IOException ;
38
+ import java .util .AbstractMap ;
39
+ import java .util .ArrayList ;
40
+ import java .util .Collection ;
41
+ import java .util .LinkedHashMap ;
42
+ import java .util .List ;
43
+ import java .util .Map ;
25
44
26
45
/**
27
46
* {@code ProtoLiteCelValueConverter} handles bidirectional conversion between native Java and
@@ -43,6 +62,270 @@ public static ProtoLiteCelValueConverter newInstance(
43
62
return new ProtoLiteCelValueConverter (celLiteDescriptorPool );
44
63
}
45
64
65
+ private static Object readPrimitiveField (
66
+ CodedInputStream inputStream , FieldLiteDescriptor fieldDescriptor ) throws IOException {
67
+ switch (fieldDescriptor .getProtoFieldType ()) {
68
+ case SINT32 :
69
+ return inputStream .readSInt32 ();
70
+ case SINT64 :
71
+ return inputStream .readSInt64 ();
72
+ case INT32 :
73
+ case ENUM :
74
+ return inputStream .readInt32 ();
75
+ case INT64 :
76
+ return inputStream .readInt64 ();
77
+ case UINT32 :
78
+ return UnsignedLong .fromLongBits (inputStream .readUInt32 ());
79
+ case UINT64 :
80
+ return UnsignedLong .fromLongBits (inputStream .readUInt64 ());
81
+ case BOOL :
82
+ return inputStream .readBool ();
83
+ case FLOAT :
84
+ case FIXED32 :
85
+ case SFIXED32 :
86
+ return readFixed32BitField (inputStream , fieldDescriptor );
87
+ case DOUBLE :
88
+ case FIXED64 :
89
+ case SFIXED64 :
90
+ return readFixed64BitField (inputStream , fieldDescriptor );
91
+ default :
92
+ throw new IllegalStateException (
93
+ "Unexpected field type: " + fieldDescriptor .getProtoFieldType ());
94
+ }
95
+ }
96
+
97
+ private static Object readFixed32BitField (
98
+ CodedInputStream inputStream , FieldLiteDescriptor fieldDescriptor ) throws IOException {
99
+ switch (fieldDescriptor .getProtoFieldType ()) {
100
+ case FLOAT :
101
+ return inputStream .readFloat ();
102
+ case FIXED32 :
103
+ case SFIXED32 :
104
+ return inputStream .readRawLittleEndian32 ();
105
+ default :
106
+ throw new IllegalStateException (
107
+ "Unexpected field type: " + fieldDescriptor .getProtoFieldType ());
108
+ }
109
+ }
110
+
111
+ private static Object readFixed64BitField (
112
+ CodedInputStream inputStream , FieldLiteDescriptor fieldDescriptor ) throws IOException {
113
+ switch (fieldDescriptor .getProtoFieldType ()) {
114
+ case DOUBLE :
115
+ return inputStream .readDouble ();
116
+ case FIXED64 :
117
+ case SFIXED64 :
118
+ return inputStream .readRawLittleEndian64 ();
119
+ default :
120
+ throw new IllegalStateException (
121
+ "Unexpected field type: " + fieldDescriptor .getProtoFieldType ());
122
+ }
123
+ }
124
+
125
+ private Object readLengthDelimitedField (
126
+ CodedInputStream inputStream , FieldLiteDescriptor fieldDescriptor ) throws IOException {
127
+ FieldLiteDescriptor .Type fieldType = fieldDescriptor .getProtoFieldType ();
128
+
129
+ switch (fieldType ) {
130
+ case BYTES :
131
+ return inputStream .readBytes ();
132
+ case MESSAGE :
133
+ MessageLite .Builder builder =
134
+ getDefaultMessageBuilder (fieldDescriptor .getFieldProtoTypeName ());
135
+
136
+ inputStream .readMessage (builder , ExtensionRegistryLite .getEmptyRegistry ());
137
+ return builder .build ();
138
+ case STRING :
139
+ return inputStream .readStringRequireUtf8 ();
140
+ default :
141
+ throw new IllegalStateException ("Unexpected field type: " + fieldType );
142
+ }
143
+ }
144
+
145
+ private MessageLite .Builder getDefaultMessageBuilder (String protoTypeName ) {
146
+ return descriptorPool .getDescriptorOrThrow (protoTypeName ).newMessageBuilder ();
147
+ }
148
+
149
+ CelValue getDefaultCelValue (String protoTypeName , String fieldName ) {
150
+ MessageLiteDescriptor messageDescriptor = descriptorPool .getDescriptorOrThrow (protoTypeName );
151
+ FieldLiteDescriptor fieldDescriptor = messageDescriptor .getByFieldNameOrThrow (fieldName );
152
+
153
+ Object defaultValue = getDefaultValue (fieldDescriptor );
154
+ if (defaultValue instanceof MessageLite ) {
155
+ return fromProtoMessageToCelValue (
156
+ fieldDescriptor .getFieldProtoTypeName (), (MessageLite ) defaultValue );
157
+ } else {
158
+ return fromJavaObjectToCelValue (getDefaultValue (fieldDescriptor ));
159
+ }
160
+ }
161
+
162
+ private Object getDefaultValue (FieldLiteDescriptor fieldDescriptor ) {
163
+ FieldLiteDescriptor .CelFieldValueType celFieldValueType =
164
+ fieldDescriptor .getCelFieldValueType ();
165
+ switch (celFieldValueType ) {
166
+ case LIST :
167
+ return ImmutableList .of ();
168
+ case MAP :
169
+ return ImmutableMap .of ();
170
+ case SCALAR :
171
+ return getScalarDefaultValue (fieldDescriptor );
172
+ }
173
+ throw new IllegalStateException ("Unexpected cel field value type: " + celFieldValueType );
174
+ }
175
+
176
+ private Object getScalarDefaultValue (FieldLiteDescriptor fieldDescriptor ) {
177
+ JavaType type = fieldDescriptor .getJavaType ();
178
+ switch (type ) {
179
+ case INT :
180
+ return fieldDescriptor .getProtoFieldType ().equals (FieldLiteDescriptor .Type .UINT32 )
181
+ ? UnsignedLong .ZERO
182
+ : Defaults .defaultValue (long .class );
183
+ case LONG :
184
+ return fieldDescriptor .getProtoFieldType ().equals (FieldLiteDescriptor .Type .UINT64 )
185
+ ? UnsignedLong .ZERO
186
+ : Defaults .defaultValue (long .class );
187
+ case ENUM :
188
+ return Defaults .defaultValue (long .class );
189
+ case FLOAT :
190
+ return Defaults .defaultValue (float .class );
191
+ case DOUBLE :
192
+ return Defaults .defaultValue (double .class );
193
+ case BOOLEAN :
194
+ return Defaults .defaultValue (boolean .class );
195
+ case STRING :
196
+ return "" ;
197
+ case BYTE_STRING :
198
+ return ByteString .EMPTY ;
199
+ case MESSAGE :
200
+ if (WellKnownProto .isWrapperType (fieldDescriptor .getFieldProtoTypeName ())) {
201
+ return NullValue .NULL_VALUE ;
202
+ }
203
+
204
+ return getDefaultMessageBuilder (fieldDescriptor .getFieldProtoTypeName ()).build ();
205
+ }
206
+ throw new IllegalStateException ("Unexpected java type: " + type );
207
+ }
208
+
209
+ private ImmutableList <Object > readPackedRepeatedFields (
210
+ CodedInputStream inputStream , FieldLiteDescriptor fieldDescriptor ) throws IOException {
211
+ int length = inputStream .readInt32 ();
212
+ int oldLimit = inputStream .pushLimit (length );
213
+ ImmutableList .Builder <Object > builder = ImmutableList .builder ();
214
+ while (inputStream .getBytesUntilLimit () > 0 ) {
215
+ builder .add (readPrimitiveField (inputStream , fieldDescriptor ));
216
+ }
217
+ inputStream .popLimit (oldLimit );
218
+ return builder .build ();
219
+ }
220
+
221
+ private Map .Entry <Object , Object > readSingleMapEntry (
222
+ CodedInputStream inputStream , FieldLiteDescriptor fieldDescriptor ) throws IOException {
223
+ ImmutableMap <String , Object > singleMapEntry =
224
+ readAllFields (inputStream .readByteArray (), fieldDescriptor .getFieldProtoTypeName ());
225
+ Object key = checkNotNull (singleMapEntry .get ("key" ));
226
+ Object value = checkNotNull (singleMapEntry .get ("value" ));
227
+
228
+ return new AbstractMap .SimpleEntry <>(key , value );
229
+ }
230
+
231
+ @ VisibleForTesting
232
+ ImmutableMap <String , Object > readAllFields (byte [] bytes , String protoTypeName )
233
+ throws IOException {
234
+ // TODO: Handle unknown fields by collecting them into a separate map.
235
+ MessageLiteDescriptor messageDescriptor = descriptorPool .getDescriptorOrThrow (protoTypeName );
236
+ CodedInputStream inputStream = CodedInputStream .newInstance (bytes );
237
+
238
+ ImmutableMap .Builder <String , Object > fieldValues = ImmutableMap .builder ();
239
+ Map <Integer , List <Object >> repeatedFieldValues = new LinkedHashMap <>();
240
+ Map <Integer , Map <Object , Object >> mapFieldValues = new LinkedHashMap <>();
241
+ for (int tag = inputStream .readTag (); tag != 0 ; tag = inputStream .readTag ()) {
242
+ int tagWireType = WireFormat .getTagWireType (tag );
243
+ int fieldNumber = WireFormat .getTagFieldNumber (tag );
244
+ FieldLiteDescriptor fieldDescriptor = messageDescriptor .getByFieldNumberOrThrow (fieldNumber );
245
+
246
+ Object payload ;
247
+ switch (tagWireType ) {
248
+ case WireFormat .WIRETYPE_VARINT :
249
+ payload = readPrimitiveField (inputStream , fieldDescriptor );
250
+ break ;
251
+ case WireFormat .WIRETYPE_FIXED32 :
252
+ payload = readFixed32BitField (inputStream , fieldDescriptor );
253
+ break ;
254
+ case WireFormat .WIRETYPE_FIXED64 :
255
+ payload = readFixed64BitField (inputStream , fieldDescriptor );
256
+ break ;
257
+ case WireFormat .WIRETYPE_LENGTH_DELIMITED :
258
+ CelFieldValueType celFieldValueType = fieldDescriptor .getCelFieldValueType ();
259
+ switch (celFieldValueType ) {
260
+ case LIST :
261
+ if (fieldDescriptor .getIsPacked ()) {
262
+ payload = readPackedRepeatedFields (inputStream , fieldDescriptor );
263
+ } else {
264
+ FieldLiteDescriptor .Type protoFieldType = fieldDescriptor .getProtoFieldType ();
265
+ boolean isLenDelimited =
266
+ protoFieldType .equals (FieldLiteDescriptor .Type .MESSAGE )
267
+ || protoFieldType .equals (FieldLiteDescriptor .Type .STRING )
268
+ || protoFieldType .equals (FieldLiteDescriptor .Type .BYTES );
269
+ if (!isLenDelimited ) {
270
+ throw new IllegalStateException (
271
+ "Unexpected field type encountered for LEN-Delimited record: "
272
+ + protoFieldType );
273
+ }
274
+
275
+ payload = readLengthDelimitedField (inputStream , fieldDescriptor );
276
+ }
277
+ break ;
278
+ case MAP :
279
+ Map <Object , Object > fieldMap =
280
+ mapFieldValues .computeIfAbsent (fieldNumber , (unused ) -> new LinkedHashMap <>());
281
+ Map .Entry <Object , Object > mapEntry = readSingleMapEntry (inputStream , fieldDescriptor );
282
+ fieldMap .put (mapEntry .getKey (), mapEntry .getValue ());
283
+ payload = fieldMap ;
284
+ break ;
285
+ default :
286
+ payload = readLengthDelimitedField (inputStream , fieldDescriptor );
287
+ break ;
288
+ }
289
+ break ;
290
+ case WireFormat .WIRETYPE_START_GROUP :
291
+ case WireFormat .WIRETYPE_END_GROUP :
292
+ // TODO: Support groups
293
+ throw new UnsupportedOperationException ("Groups are not supported" );
294
+ default :
295
+ throw new IllegalArgumentException ("Unexpected wire type: " + tagWireType );
296
+ }
297
+
298
+ if (fieldDescriptor .getCelFieldValueType ().equals (CelFieldValueType .LIST )) {
299
+ String fieldName = fieldDescriptor .getFieldName ();
300
+ List <Object > repeatedValues =
301
+ repeatedFieldValues .computeIfAbsent (
302
+ fieldNumber ,
303
+ (unused ) -> {
304
+ List <Object > newList = new ArrayList <>();
305
+ fieldValues .put (fieldName , newList );
306
+ return newList ;
307
+ });
308
+
309
+ if (payload instanceof Collection ) {
310
+ repeatedValues .addAll ((Collection <?>) payload );
311
+ } else {
312
+ repeatedValues .add (payload );
313
+ }
314
+ } else {
315
+ fieldValues .put (fieldDescriptor .getFieldName (), payload );
316
+ }
317
+ }
318
+
319
+ // Protobuf encoding follows a "last one wins" semantics. This means for duplicated fields,
320
+ // we accept the last value encountered.
321
+ return fieldValues .buildKeepingLast ();
322
+ }
323
+
324
+ ImmutableMap <String , Object > readAllFields (MessageLite msg , String protoTypeName )
325
+ throws IOException {
326
+ return readAllFields (msg .toByteArray (), protoTypeName );
327
+ }
328
+
46
329
@ Override
47
330
public CelValue fromProtoMessageToCelValue (String protoTypeName , MessageLite msg ) {
48
331
checkNotNull (msg );
0 commit comments