Skip to content

Commit 7737ed1

Browse files
jnthntatumcopybara-github
authored andcommitted
Add is_contextual flag to FunctionDescriptor.
This is used to mark a function as impure or context dependent. This blocks constant folding from attempting to evaluate the function. PiperOrigin-RevId: 827691703
1 parent dd03e07 commit 7737ed1

File tree

7 files changed

+226
-34
lines changed

7 files changed

+226
-34
lines changed

common/function_descriptor.h

Lines changed: 50 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -26,30 +26,68 @@
2626

2727
namespace cel {
2828

29+
struct FunctionDescriptorOptions {
30+
// If true (strict, default), error or unknown arguments are propagated
31+
// instead of calling the function. if false (non-strict), the function may
32+
// receive error or unknown values as arguments.
33+
bool is_strict = true;
34+
35+
// Whether the function is impure or context-sensitive.
36+
//
37+
// Impure functions depend on state other than the arguments received during
38+
// the CEL expression evaluation or have visible side effects. This breaks
39+
// some of the assumptions of the CEL evaluation model. This flag is used as a
40+
// hint to the planner that some optimizations are not safe or not effective.
41+
bool is_contextual = false;
42+
};
43+
2944
// Coarsely describes a function for the purpose of runtime resolution of
3045
// overloads.
3146
class FunctionDescriptor final {
3247
public:
3348
FunctionDescriptor(absl::string_view name, bool receiver_style,
34-
std::vector<Kind> types, bool is_strict = true)
35-
: impl_(std::make_shared<Impl>(name, receiver_style, std::move(types),
36-
is_strict)) {}
49+
std::vector<Kind> types, bool is_strict)
50+
: impl_(std::make_shared<Impl>(
51+
name, std::move(types), receiver_style,
52+
FunctionDescriptorOptions{is_strict,
53+
/*is_contextual=*/false})) {}
54+
55+
FunctionDescriptor(absl::string_view name, bool receiver_style,
56+
std::vector<Kind> types, bool is_strict,
57+
bool is_contextual)
58+
: impl_(std::make_shared<Impl>(
59+
name, std::move(types), receiver_style,
60+
FunctionDescriptorOptions{is_strict, is_contextual})) {}
61+
62+
FunctionDescriptor(absl::string_view name, bool is_receiver_style,
63+
std::vector<Kind> types,
64+
FunctionDescriptorOptions options = {})
65+
: impl_(std::make_shared<Impl>(name, std::move(types), is_receiver_style,
66+
options)) {}
3767

3868
// Function name.
3969
const std::string& name() const { return impl_->name; }
4070

4171
// Whether function is receiver style i.e. true means arg0.name(args[1:]...).
42-
bool receiver_style() const { return impl_->receiver_style; }
72+
bool receiver_style() const { return impl_->is_receiver_style; }
4373

44-
// The argmument types the function accepts.
74+
// The argument types the function accepts.
4575
//
4676
// TODO(uncreated-issue/17): make this kinds
4777
const std::vector<Kind>& types() const { return impl_->types; }
4878

4979
// if true (strict, default), error or unknown arguments are propagated
5080
// instead of calling the function. if false (non-strict), the function may
5181
// receive error or unknown values as arguments.
52-
bool is_strict() const { return impl_->is_strict; }
82+
bool is_strict() const { return impl_->options.is_strict; }
83+
84+
// Whether the function is contextual (impure).
85+
//
86+
// Contextual functions depend on state other than the arguments received in
87+
// the CEL expression evaluation or have visible side effects. This breaks
88+
// some of the assumptions of CEL. This flag is used as a hint to the planner
89+
// that some optimizations are not safe or not effective.
90+
bool is_contextual() const { return impl_->options.is_contextual; }
5391

5492
// Helper for matching a descriptor. This tests that the shape is the same --
5593
// |other| accepts the same number and types of arguments and is the same call
@@ -65,17 +103,17 @@ class FunctionDescriptor final {
65103

66104
private:
67105
struct Impl final {
68-
Impl(absl::string_view name, bool receiver_style, std::vector<Kind> types,
69-
bool is_strict)
106+
Impl(absl::string_view name, std::vector<Kind> types,
107+
bool is_receiver_style, FunctionDescriptorOptions options)
70108
: name(name),
71109
types(std::move(types)),
72-
receiver_style(receiver_style),
73-
is_strict(is_strict) {}
110+
is_receiver_style(is_receiver_style),
111+
options(options) {}
74112

75113
std::string name;
76114
std::vector<Kind> types;
77-
bool receiver_style;
78-
bool is_strict;
115+
bool is_receiver_style;
116+
FunctionDescriptorOptions options;
79117
};
80118

81119
std::shared_ptr<const Impl> impl_;

eval/compiler/constant_folding.cc

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,17 @@ IsConst IsConstExpr(const Expr& expr, const Resolver& resolver) {
155155
return IsConst::kNonConst;
156156
}
157157

158+
auto overloads =
159+
resolver.FindOverloads(call.function(), call.has_target(), arg_len);
160+
// Check for any contextual overloads. If there are any, we cowardly
161+
// avoid constant folding instead of trying to check if one of the
162+
// overloads would be safe to use.
163+
for (const auto& overload : overloads) {
164+
if (overload.descriptor.is_contextual()) {
165+
return IsConst::kNonConst;
166+
}
167+
}
168+
158169
return IsConst::kConditional;
159170
}
160171
case ExprKindCase::kUnspecifiedExpr:

runtime/BUILD

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,7 @@ cc_library(
116116
deps =
117117
[
118118
":function_registry",
119+
"//common:function_descriptor",
119120
"@com_google_absl//absl/status",
120121
"@com_google_absl//absl/strings",
121122
],
@@ -320,7 +321,7 @@ cc_library(
320321
deps = [
321322
":runtime",
322323
":runtime_builder",
323-
"//common:native_type",
324+
"//common:typeinfo",
324325
"//eval/compiler:constant_folding",
325326
"//internal:casts",
326327
"//internal:noop_delete",
@@ -342,11 +343,14 @@ cc_test(
342343
deps = [
343344
":activation",
344345
":constant_folding",
346+
":function",
345347
":register_function_helper",
346348
":runtime_builder",
347349
":runtime_options",
348350
":standard_runtime_builder_factory",
349351
"//base:function_adapter",
352+
"//common:function_descriptor",
353+
"//common:kind",
350354
"//common:value",
351355
"//extensions/protobuf:runtime_adapter",
352356
"//internal:testing",

runtime/constant_folding.cc

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
#include "absl/log/absl_check.h"
2323
#include "absl/status/status.h"
2424
#include "absl/status/statusor.h"
25-
#include "common/native_type.h"
25+
#include "common/typeinfo.h"
2626
#include "eval/compiler/constant_folding.h"
2727
#include "internal/casts.h"
2828
#include "internal/noop_delete.h"
@@ -44,8 +44,7 @@ using ::cel::runtime_internal::RuntimeImpl;
4444
absl::StatusOr<RuntimeImpl* absl_nonnull> RuntimeImplFromBuilder(
4545
RuntimeBuilder& builder ABSL_ATTRIBUTE_LIFETIME_BOUND) {
4646
Runtime& runtime = RuntimeFriendAccess::GetMutableRuntime(builder);
47-
if (RuntimeFriendAccess::RuntimeTypeId(runtime) !=
48-
NativeTypeId::For<RuntimeImpl>()) {
47+
if (RuntimeFriendAccess::RuntimeTypeId(runtime) != TypeId<RuntimeImpl>()) {
4948
return absl::UnimplementedError(
5049
"constant folding only supported on the default cel::Runtime "
5150
"implementation.");

runtime/constant_folding_test.cc

Lines changed: 91 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
#include "runtime/constant_folding.h"
1616

17+
#include <memory>
1718
#include <string>
1819
#include <utility>
1920
#include <vector>
@@ -25,13 +26,13 @@
2526
#include "absl/strings/match.h"
2627
#include "absl/strings/str_cat.h"
2728
#include "base/function_adapter.h"
29+
#include "common/function_descriptor.h"
2830
#include "common/value.h"
2931
#include "extensions/protobuf/runtime_adapter.h"
3032
#include "internal/testing.h"
3133
#include "internal/testing_descriptor_pool.h"
3234
#include "parser/parser.h"
3335
#include "runtime/activation.h"
34-
#include "runtime/register_function_helper.h"
3536
#include "runtime/runtime_builder.h"
3637
#include "runtime/runtime_options.h"
3738
#include "runtime/standard_runtime_builder_factory.h"
@@ -82,8 +83,8 @@ TEST_P(ConstantFoldingExtTest, Runner) {
8283
CreateStandardRuntimeBuilder(
8384
internal::GetTestingDescriptorPool(), options));
8485

85-
auto status = RegisterHelper<BinaryFunctionAdapter<
86-
absl::StatusOr<Value>, const StringValue&, const StringValue&>>::
86+
auto status = BinaryFunctionAdapter<absl::StatusOr<Value>, const StringValue&,
87+
const StringValue&>::
8788
RegisterGlobalOverload(
8889
"prepend",
8990
[](const StringValue& value, const StringValue& prefix) {
@@ -129,14 +130,99 @@ INSTANTIATE_TEST_SUITE_P(
129130
IsBoolValue(true)},
130131
{"runtime_error", "[1, 2, 3, 4].exists(x, ['4'].all(y, y <= x))",
131132
IsErrorValue("No matching overloads")},
132-
// TODO(uncreated-issue/32): Depends on map creation
133-
// {"map_create", "{'abc': 'def', 'abd': 'deg'}.size()", 2},
133+
{"map_create", "{'abc': 'def', 'abd': 'deg'}.size()", IsIntValue(2)},
134134
{"custom_function", "prepend('def', 'abc') == 'abcdef'",
135135
IsBoolValue(true)}}),
136136

137137
[](const testing::TestParamInfo<TestCase>& info) {
138138
return info.param.name;
139139
});
140140

141+
TEST(ConstantFoldingExtTest, LazyFunctionNotFolded) {
142+
google::protobuf::Arena arena;
143+
RuntimeOptions options;
144+
145+
ASSERT_OK_AND_ASSIGN(cel::RuntimeBuilder builder,
146+
CreateStandardRuntimeBuilder(
147+
internal::GetTestingDescriptorPool(), options));
148+
int call_count = 0;
149+
using FunctionAdapter =
150+
BinaryFunctionAdapter<absl::StatusOr<Value>, const StringValue&,
151+
const StringValue&>;
152+
auto fn = FunctionAdapter::WrapFunction(
153+
[&call_count](const StringValue& value, const StringValue& prefix) {
154+
call_count++;
155+
return StringValue(absl::StrCat(prefix.ToString(), value.ToString()));
156+
});
157+
FunctionDescriptor descriptor = FunctionAdapter::CreateDescriptor(
158+
"lazy_prepend", /*receiver_style=*/false);
159+
ASSERT_THAT(builder.function_registry().RegisterLazyFunction(descriptor),
160+
IsOk());
161+
162+
ASSERT_THAT(EnableConstantFolding(builder), IsOk());
163+
164+
ASSERT_OK_AND_ASSIGN(auto runtime, std::move(builder).Build());
165+
166+
ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr,
167+
Parse("lazy_prepend('def', 'abc') == 'abcdef'"));
168+
169+
ASSERT_OK_AND_ASSIGN(auto program, ProtobufRuntimeAdapter::CreateProgram(
170+
*runtime, parsed_expr));
171+
EXPECT_EQ(call_count, 0);
172+
Activation activation;
173+
activation.InsertFunction(descriptor, std::move(fn));
174+
175+
ASSERT_OK_AND_ASSIGN(Value result, program->Evaluate(&arena, activation));
176+
EXPECT_EQ(call_count, 1);
177+
EXPECT_THAT(result, IsBoolValue(true));
178+
179+
ASSERT_OK_AND_ASSIGN(result, program->Evaluate(&arena, activation));
180+
EXPECT_EQ(call_count, 2);
181+
EXPECT_THAT(result, IsBoolValue(true));
182+
}
183+
184+
TEST(ConstantFoldingExtTest, ContextualFunctionNotFolded) {
185+
google::protobuf::Arena arena;
186+
RuntimeOptions options;
187+
ASSERT_OK_AND_ASSIGN(cel::RuntimeBuilder builder,
188+
CreateStandardRuntimeBuilder(
189+
internal::GetTestingDescriptorPool(), options));
190+
int call_count = 0;
191+
192+
auto status = BinaryFunctionAdapter<
193+
absl::StatusOr<Value>, const StringValue&,
194+
const StringValue&>::Register("contextual_prepend",
195+
/*receiver_style=*/false,
196+
[&call_count](const StringValue& value,
197+
const StringValue& prefix) {
198+
call_count++;
199+
return StringValue(absl::StrCat(
200+
prefix.ToString(), value.ToString()));
201+
},
202+
builder.function_registry(),
203+
{/*.is_strict=*/true,
204+
/*is_contextual=*/true});
205+
ASSERT_THAT(status, IsOk());
206+
207+
ASSERT_THAT(EnableConstantFolding(builder), IsOk());
208+
209+
ASSERT_OK_AND_ASSIGN(auto runtime, std::move(builder).Build());
210+
211+
ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr,
212+
Parse("contextual_prepend('def', 'abc') == 'abcdef'"));
213+
214+
ASSERT_OK_AND_ASSIGN(auto program, ProtobufRuntimeAdapter::CreateProgram(
215+
*runtime, parsed_expr));
216+
EXPECT_EQ(call_count, 0);
217+
Activation activation;
218+
ASSERT_OK_AND_ASSIGN(Value value, program->Evaluate(&arena, activation));
219+
EXPECT_EQ(call_count, 1);
220+
EXPECT_THAT(value, IsBoolValue(true));
221+
222+
ASSERT_OK_AND_ASSIGN(value, program->Evaluate(&arena, activation));
223+
EXPECT_EQ(call_count, 2);
224+
EXPECT_THAT(value, IsBoolValue(true));
225+
}
226+
141227
} // namespace
142228
} // namespace cel::extensions

0 commit comments

Comments
 (0)