Skip to content

Commit fd960cf

Browse files
l46kok authored and
copybara-github committed
Enhance CSE to handle two variable comprehensions
PiperOrigin-RevId: 803165826
1 parent 4a1ed0f commit fd960cf

19 files changed

+14647
-813
lines changed

extensions/src/main/java/dev/cel/extensions/CelExtensions.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -338,7 +338,9 @@ public static ImmutableSet<String> getAllFunctionNames() {
338338
stream(CelListsExtensions.Function.values())
339339
.map(CelListsExtensions.Function::getFunction),
340340
stream(CelRegexExtensions.Function.values())
341-
.map(CelRegexExtensions.Function::getFunction))
341+
.map(CelRegexExtensions.Function::getFunction),
342+
stream(CelComprehensionsExtensions.Function.values())
343+
.map(CelComprehensionsExtensions.Function::getFunction))
342344
.collect(toImmutableSet());
343345
}
344346

extensions/src/test/java/dev/cel/extensions/CelExtensionsTest.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,7 @@ public void getAllFunctionNames() {
187187
"lists.@sortByAssociatedKeys",
188188
"regex.replace",
189189
"regex.extract",
190-
"regex.extractAll");
190+
"regex.extractAll",
191+
"cel.@mapInsert");
191192
}
192193
}

optimizer/src/main/java/dev/cel/optimizer/AstMutator.java

Lines changed: 88 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ public final class AstMutator {
6060
private final long iterationLimit;
6161

6262
/**
63-
* Returns a new instance of a AST mutator with the iteration limit set.
63+
* Returns a new instance of an AST mutator with the iteration limit set.
6464
*
6565
* <p>Mutation is performed by walking the existing AST until the expression node to replace is
6666
* found, then the new subtree is walked to complete the mutation. Visiting of each node
@@ -205,15 +205,20 @@ public CelMutableAst renumberIdsConsecutively(CelMutableAst mutableAst) {
205205
* @param newAccuVarPrefix Prefix to use for new accumulation variable identifier name.
206206
*/
207207
public MangledComprehensionAst mangleComprehensionIdentifierNames(
208-
CelMutableAst ast, String newIterVarPrefix, String newAccuVarPrefix) {
208+
CelMutableAst ast,
209+
String newIterVarPrefix,
210+
String newIterVar2Prefix,
211+
String newAccuVarPrefix) {
209212
CelNavigableMutableAst navigableMutableAst = CelNavigableMutableAst.fromAst(ast);
210213
Predicate<CelNavigableMutableExpr> comprehensionIdentifierPredicate = x -> true;
211214
comprehensionIdentifierPredicate =
212215
comprehensionIdentifierPredicate
213216
.and(node -> node.getKind().equals(Kind.COMPREHENSION))
214-
.and(node -> !node.expr().comprehension().iterVar().startsWith(newIterVarPrefix))
215-
.and(node -> !node.expr().comprehension().accuVar().startsWith(newAccuVarPrefix));
216-
217+
.and(node -> !node.expr().comprehension().iterVar().startsWith(newIterVarPrefix + ":"))
218+
.and(node -> !node.expr().comprehension().accuVar().startsWith(newAccuVarPrefix + ":"))
219+
.and(
220+
node ->
221+
!node.expr().comprehension().iterVar2().startsWith(newIterVar2Prefix + ":"));
217222
LinkedHashMap<CelNavigableMutableExpr, MangledComprehensionType> comprehensionsToMangle =
218223
navigableMutableAst
219224
.getRoot()
@@ -226,20 +231,25 @@ public MangledComprehensionAst mangleComprehensionIdentifierNames(
226231
// Ensure the iter_var, iter_var2 or the comprehension result is actually referenced in the
227232
// loop_step. If it's not, we can skip mangling.
228233
String iterVar = node.expr().comprehension().iterVar();
234+
String iterVar2 = node.expr().comprehension().iterVar2();
229235
String result = node.expr().comprehension().result().ident().name();
230236
return CelNavigableMutableExpr.fromExpr(node.expr().comprehension().loopStep())
231237
.allNodes()
232238
.filter(subNode -> subNode.getKind().equals(Kind.IDENT))
233239
.map(subNode -> subNode.expr().ident())
234240
.anyMatch(
235-
ident -> ident.name().contains(iterVar) || ident.name().contains(result));
241+
ident ->
242+
ident.name().contains(iterVar)
243+
|| ident.name().contains(iterVar2)
244+
|| ident.name().contains(result));
236245
})
237246
.collect(
238247
Collectors.toMap(
239248
k -> k,
240249
v -> {
241250
CelMutableComprehension comprehension = v.expr().comprehension();
242251
String iterVar = comprehension.iterVar();
252+
String iterVar2 = comprehension.iterVar2();
243253
// Identifiers to mangle could be the iteration variable, comprehension
244254
// result or both, but at least one has to exist.
245255
// As an example, [1,2].map(i, 3) would result in optional.empty for iteration
@@ -253,6 +263,16 @@ public MangledComprehensionAst mangleComprehensionIdentifierNames(
253263
&& loopStepNode.expr().ident().name().equals(iterVar))
254264
.map(CelNavigableMutableExpr::id)
255265
.findAny();
266+
Optional<Long> iterVar2Id =
267+
CelNavigableMutableExpr.fromExpr(comprehension.loopStep())
268+
.allNodes()
269+
.filter(
270+
loopStepNode ->
271+
!iterVar2.isEmpty()
272+
&& loopStepNode.getKind().equals(Kind.IDENT)
273+
&& loopStepNode.expr().ident().name().equals(iterVar2))
274+
.map(CelNavigableMutableExpr::id)
275+
.findAny();
256276
Optional<CelType> iterVarType =
257277
iterVarId.map(
258278
id ->
@@ -264,6 +284,17 @@ public MangledComprehensionAst mangleComprehensionIdentifierNames(
264284
"Checked type not present for iteration"
265285
+ " variable: "
266286
+ iterVarId)));
287+
Optional<CelType> iterVar2Type =
288+
iterVar2Id.map(
289+
id ->
290+
navigableMutableAst
291+
.getType(id)
292+
.orElseThrow(
293+
() ->
294+
new NoSuchElementException(
295+
"Checked type not present for iteration"
296+
+ " variable: "
297+
+ iterVar2Id)));
267298
CelType resultType =
268299
navigableMutableAst
269300
.getType(comprehension.result().id())
@@ -273,7 +304,7 @@ public MangledComprehensionAst mangleComprehensionIdentifierNames(
273304
"Result type was not present for the comprehension ID: "
274305
+ comprehension.result().id()));
275306

276-
return MangledComprehensionType.of(iterVarType, resultType);
307+
return MangledComprehensionType.of(iterVarType, iterVar2Type, resultType);
277308
},
278309
(x, y) -> {
279310
throw new IllegalStateException(
@@ -299,19 +330,22 @@ public MangledComprehensionAst mangleComprehensionIdentifierNames(
299330
MangledComprehensionName mangledComprehensionName =
300331
getMangledComprehensionName(
301332
newIterVarPrefix,
333+
newIterVar2Prefix,
302334
newAccuVarPrefix,
303335
comprehensionNode,
304336
comprehensionLevelToType,
305337
comprehensionEntryType);
306338
mangledIdentNamesToType.put(mangledComprehensionName, comprehensionEntryType);
307339

308340
String iterVar = comprehensionExpr.comprehension().iterVar();
341+
String iterVar2 = comprehensionExpr.comprehension().iterVar2();
309342
String accuVar = comprehensionExpr.comprehension().accuVar();
310343
mutatedComprehensionExpr =
311344
mangleIdentsInComprehensionExpr(
312345
mutatedComprehensionExpr,
313346
comprehensionExpr,
314347
iterVar,
348+
iterVar2,
315349
accuVar,
316350
mangledComprehensionName);
317351
// Repeat the mangling process for the macro source.
@@ -320,6 +354,7 @@ public MangledComprehensionAst mangleComprehensionIdentifierNames(
320354
newSource,
321355
mutatedComprehensionExpr,
322356
iterVar,
357+
iterVar2,
323358
mangledComprehensionName,
324359
comprehensionExpr.id());
325360
iterCount++;
@@ -339,6 +374,7 @@ public MangledComprehensionAst mangleComprehensionIdentifierNames(
339374

340375
private static MangledComprehensionName getMangledComprehensionName(
341376
String newIterVarPrefix,
377+
String newIterVar2Prefix,
342378
String newResultPrefix,
343379
CelNavigableMutableExpr comprehensionNode,
344380
Table<Integer, MangledComprehensionType, MangledComprehensionName> comprehensionLevelToType,
@@ -356,7 +392,11 @@ private static MangledComprehensionName getMangledComprehensionName(
356392
newIterVarPrefix + ":" + comprehensionNestingLevel + ":" + uniqueTypeIdx;
357393
String mangledResultName =
358394
newResultPrefix + ":" + comprehensionNestingLevel + ":" + uniqueTypeIdx;
359-
mangledComprehensionName = MangledComprehensionName.of(mangledIterVarName, mangledResultName);
395+
String mangledIterVar2Name =
396+
newIterVar2Prefix + ":" + comprehensionNestingLevel + ":" + uniqueTypeIdx;
397+
398+
mangledComprehensionName =
399+
MangledComprehensionName.of(mangledIterVarName, mangledIterVar2Name, mangledResultName);
360400
comprehensionLevelToType.put(
361401
comprehensionNestingLevel, comprehensionEntryType, mangledComprehensionName);
362402
}
@@ -509,6 +549,7 @@ private CelMutableExpr mangleIdentsInComprehensionExpr(
509549
CelMutableExpr root,
510550
CelMutableExpr comprehensionExpr,
511551
String originalIterVar,
552+
String originalIterVar2,
512553
String originalAccuVar,
513554
MangledComprehensionName mangledComprehensionName) {
514555
CelMutableComprehension comprehension = comprehensionExpr.comprehension();
@@ -517,11 +558,18 @@ private CelMutableExpr mangleIdentsInComprehensionExpr(
517558
replaceIdentName(comprehensionExpr, originalAccuVar, mangledComprehensionName.resultName());
518559

519560
comprehension.setIterVar(mangledComprehensionName.iterVarName());
561+
520562
// Most standard macros set accu_var as __result__, but not all (ex: cel.bind).
521563
if (comprehension.accuVar().equals(originalAccuVar)) {
522564
comprehension.setAccuVar(mangledComprehensionName.resultName());
523565
}
524566

567+
if (!originalIterVar2.isEmpty()) {
568+
comprehension.setIterVar2(mangledComprehensionName.iterVar2Name());
569+
replaceIdentName(
570+
comprehension.loopStep(), originalIterVar2, mangledComprehensionName.iterVar2Name());
571+
}
572+
525573
return mutateExpr(NO_OP_ID_GENERATOR, root, comprehensionExpr, comprehensionExpr.id());
526574
}
527575

@@ -560,6 +608,7 @@ private CelMutableSource mangleIdentsInMacroSource(
560608
CelMutableSource sourceBuilder,
561609
CelMutableExpr mutatedComprehensionExpr,
562610
String originalIterVar,
611+
String originalIterVar2,
563612
MangledComprehensionName mangledComprehensionName,
564613
long originalComprehensionId) {
565614
if (!sourceBuilder.getMacroCalls().containsKey(originalComprehensionId)) {
@@ -583,14 +632,25 @@ private CelMutableSource mangleIdentsInMacroSource(
583632
// macro call expression.
584633
CelMutableExpr identToMangle = macroExpr.call().args().get(0);
585634
if (identToMangle.ident().name().equals(originalIterVar)) {
586-
// if (identToMangle.identOrDefault().name().equals(originalIterVar)) {
587635
macroExpr =
588636
mutateExpr(
589637
NO_OP_ID_GENERATOR,
590638
macroExpr,
591639
CelMutableExpr.ofIdent(mangledComprehensionName.iterVarName()),
592640
identToMangle.id());
593641
}
642+
if (!originalIterVar2.isEmpty()) {
643+
// Similarly by convention, iter_var2 is always the second argument of the macro call.
644+
identToMangle = macroExpr.call().args().get(1);
645+
if (identToMangle.ident().name().equals(originalIterVar2)) {
646+
macroExpr =
647+
mutateExpr(
648+
NO_OP_ID_GENERATOR,
649+
macroExpr,
650+
CelMutableExpr.ofIdent(mangledComprehensionName.iterVar2Name()),
651+
identToMangle.id());
652+
}
653+
}
594654

595655
newSource.addMacroCalls(originalComprehensionId, macroExpr);
596656

@@ -794,7 +854,7 @@ private static void unwrapListArgumentsInMacroCallExpr(
794854
newMacroCall.addArgs(
795855
existingMacroCall.args().get(0)); // iter_var is first argument of the call by convention
796856

797-
CelMutableList extraneousList = null;
857+
CelMutableList extraneousList;
798858
if (loopStepArgs.size() == 2) {
799859
extraneousList = loopStepArgs.get(1).list();
800860
} else {
@@ -874,14 +934,22 @@ private static MangledComprehensionAst of(
874934
@AutoValue
875935
public abstract static class MangledComprehensionType {
876936

877-
/** Type of iter_var */
937+
/**
938+
* Type of iter_var. Empty if iter_var is not referenced in the expression anywhere (ex: "i" in
939+
* "[1].exists(i, true)").
940+
*/
878941
public abstract Optional<CelType> iterVarType();
879942

943+
/** Type of iter_var2. Empty if iter_var2 is unset or is not referenced in the loop step. */
944+
public abstract Optional<CelType> iterVar2Type();
945+
880946
/** Type of comprehension result */
881947
public abstract CelType resultType();
882948

883-
private static MangledComprehensionType of(Optional<CelType> iterVarType, CelType resultType) {
884-
return new AutoValue_AstMutator_MangledComprehensionType(iterVarType, resultType);
949+
private static MangledComprehensionType of(
950+
Optional<CelType> iterVarType, Optional<CelType> iterVarType2, CelType resultType) {
951+
return new AutoValue_AstMutator_MangledComprehensionType(
952+
iterVarType, iterVarType2, resultType);
885953
}
886954
}
887955

@@ -895,11 +963,16 @@ public abstract static class MangledComprehensionName {
895963
/** Mangled name for iter_var */
896964
public abstract String iterVarName();
897965

966+
/** Mangled name for iter_var2 */
967+
public abstract String iterVar2Name();
968+
898969
/** Mangled name for comprehension result */
899970
public abstract String resultName();
900971

901-
private static MangledComprehensionName of(String iterVarName, String resultName) {
902-
return new AutoValue_AstMutator_MangledComprehensionName(iterVarName, resultName);
972+
private static MangledComprehensionName of(
973+
String iterVarName, String iterVar2Name, String resultName) {
974+
return new AutoValue_AstMutator_MangledComprehensionName(
975+
iterVarName, iterVar2Name, resultName);
903976
}
904977
}
905978
}

optimizer/src/main/java/dev/cel/optimizer/optimizers/SubexpressionOptimizer.java

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,7 @@ public class SubexpressionOptimizer implements CelAstOptimizer {
9191
new SubexpressionOptimizer(SubexpressionOptimizerOptions.newBuilder().build());
9292
private static final String BIND_IDENTIFIER_PREFIX = "@r";
9393
private static final String MANGLED_COMPREHENSION_ITER_VAR_PREFIX = "@it";
94+
private static final String MANGLED_COMPREHENSION_ITER_VAR2_PREFIX = "@it2";
9495
private static final String MANGLED_COMPREHENSION_ACCU_VAR_PREFIX = "@ac";
9596
private static final String CEL_BLOCK_FUNCTION = "cel.@block";
9697
private static final String BLOCK_INDEX_PREFIX = "@index";
@@ -136,6 +137,7 @@ private OptimizationResult optimizeUsingCelBlock(CelAbstractSyntaxTree ast, Cel
136137
astMutator.mangleComprehensionIdentifierNames(
137138
astToModify,
138139
MANGLED_COMPREHENSION_ITER_VAR_PREFIX,
140+
MANGLED_COMPREHENSION_ITER_VAR2_PREFIX,
139141
MANGLED_COMPREHENSION_ACCU_VAR_PREFIX);
140142
astToModify = mangledComprehensionAst.mutableAst();
141143
CelMutableSource sourceToModify = astToModify.source();
@@ -196,6 +198,12 @@ private OptimizationResult optimizeUsingCelBlock(CelAbstractSyntaxTree ast, Cel
196198
iterVarType ->
197199
newVarDecls.add(
198200
CelVarDecl.newVarDeclaration(name.iterVarName(), iterVarType)));
201+
type.iterVar2Type()
202+
.ifPresent(
203+
iterVar2Type ->
204+
newVarDecls.add(
205+
CelVarDecl.newVarDeclaration(name.iterVar2Name(), iterVar2Type)));
206+
199207
newVarDecls.add(CelVarDecl.newVarDeclaration(name.resultName(), type.resultType()));
200208
});
201209

@@ -445,16 +453,16 @@ private boolean containsComprehensionIdentInSubexpr(CelNavigableMutableExpr navE
445453
navExpr
446454
.allNodes()
447455
.filter(
448-
node ->
449-
node.getKind().equals(Kind.IDENT)
450-
&& (node.expr()
451-
.ident()
452-
.name()
453-
.startsWith(MANGLED_COMPREHENSION_ITER_VAR_PREFIX)
454-
|| node.expr()
455-
.ident()
456-
.name()
457-
.startsWith(MANGLED_COMPREHENSION_ACCU_VAR_PREFIX)))
456+
node -> {
457+
if (!node.getKind().equals(Kind.IDENT)) {
458+
return false;
459+
}
460+
461+
String identName = node.expr().ident().name();
462+
return identName.startsWith(MANGLED_COMPREHENSION_ITER_VAR_PREFIX)
463+
|| identName.startsWith(MANGLED_COMPREHENSION_ITER_VAR2_PREFIX)
464+
|| identName.startsWith(MANGLED_COMPREHENSION_ACCU_VAR_PREFIX);
465+
})
458466
.collect(toImmutableList());
459467

460468
if (comprehensionIdents.isEmpty()) {

0 commit comments

Comments
 (0)