Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions extensions/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -712,10 +712,16 @@ cc_test(
":regex_ext",
"//checker:standard_library",
"//checker:validation_result",
"//common:kind",
"//common:value",
"//common:value_testing",
"//compiler",
"//compiler:compiler_factory",
"//eval/public:activation",
"//eval/public:cel_expr_builder_factory",
"//eval/public:cel_expression",
"//eval/public:cel_function_registry",
"//eval/public:cel_options",
"//extensions/protobuf:runtime_adapter",
"//internal:status_macros",
"//internal:testing",
Expand Down
19 changes: 12 additions & 7 deletions extensions/regex_ext.cc
Original file line number Diff line number Diff line change
Expand Up @@ -232,10 +232,13 @@ Value ReplaceN(const StringValue& target, const StringValue& regex,
return StringValue::From(std::move(output), arena);
}

absl::Status RegisterRegexExtensionFunctions(FunctionRegistry& registry) {
CEL_RETURN_IF_ERROR(
(BinaryFunctionAdapter<absl::StatusOr<Value>, StringValue, StringValue>::
RegisterGlobalOverload("regex.extract", &Extract, registry)));
absl::Status RegisterRegexExtensionFunctions(FunctionRegistry& registry,
bool disable_extract) {
if (!disable_extract) {
CEL_RETURN_IF_ERROR((
BinaryFunctionAdapter<absl::StatusOr<Value>, StringValue, StringValue>::
RegisterGlobalOverload("regex.extract", &Extract, registry)));
}
CEL_RETURN_IF_ERROR(
(BinaryFunctionAdapter<absl::StatusOr<Value>, StringValue, StringValue>::
RegisterGlobalOverload("regex.extractAll", &ExtractAll, registry)));
Expand Down Expand Up @@ -306,16 +309,18 @@ absl::Status RegisterRegexExtensionFunctions(RuntimeBuilder& builder) {
}
if (runtime.expr_builder().options().enable_regex) {
CEL_RETURN_IF_ERROR(
RegisterRegexExtensionFunctions(builder.function_registry()));
RegisterRegexExtensionFunctions(builder.function_registry(),
/*disable_extract=*/false));
}
return absl::OkStatus();
}

absl::Status RegisterRegexExtensionFunctions(
google::api::expr::runtime::CelFunctionRegistry* registry,
const google::api::expr::runtime::InterpreterOptions& options) {
if (!options.enable_regex) {
return RegisterRegexExtensionFunctions(registry->InternalGetRegistry());
if (options.enable_regex) {
return RegisterRegexExtensionFunctions(registry->InternalGetRegistry(),
/*disable_extract=*/true);
}
return absl::OkStatus();
}
Expand Down
8 changes: 7 additions & 1 deletion extensions/regex_ext.h
Original file line number Diff line number Diff line change
Expand Up @@ -84,10 +84,16 @@

namespace cel::extensions {

// Register extension functions for regular expressions.
// Register extension functions for regular expressions for
// google::api::expr::runtime::CelValue runtime.
//
// Note: CelValue does not support optional types, so regex.extract is
// unsupported.
absl::Status RegisterRegexExtensionFunctions(
google::api::expr::runtime::CelFunctionRegistry* registry,
const google::api::expr::runtime::InterpreterOptions& options);

// Register extension functions for regular expressions.
absl::Status RegisterRegexExtensionFunctions(RuntimeBuilder& builder);

// Type check declarations for the regex extension library.
Expand Down
135 changes: 114 additions & 21 deletions extensions/regex_ext_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,16 @@
#include "absl/strings/string_view.h"
#include "checker/standard_library.h"
#include "checker/validation_result.h"
#include "common/kind.h"
#include "common/value.h"
#include "common/value_testing.h"
#include "compiler/compiler.h"
#include "compiler/compiler_factory.h"
#include "eval/public/activation.h"
#include "eval/public/cel_expr_builder_factory.h"
#include "eval/public/cel_expression.h"
#include "eval/public/cel_function_registry.h"
#include "eval/public/cel_options.h"
#include "extensions/protobuf/runtime_adapter.h"
#include "internal/status_macros.h"
#include "internal/testing.h"
Expand All @@ -49,16 +55,120 @@ namespace {
using ::absl_testing::IsOk;
using ::absl_testing::IsOkAndHolds;
using ::absl_testing::StatusIs;
using ::cel::test::BoolValueIs;
using ::cel::test::ErrorValueIs;
using ::cel::test::OptionalValueIs;
using ::cel::test::OptionalValueIsEmpty;
using ::cel::test::StringValueIs;
using ::google::api::expr::parser::Parse;
using test::BoolValueIs;
using test::OptionalValueIs;
using test::OptionalValueIsEmpty;
using test::StringValueIs;
using ::google::api::expr::runtime::CelExpressionBuilder;
using ::google::api::expr::runtime::CelFunctionRegistry;
using ::google::api::expr::runtime::CreateCelExpressionBuilder;
using ::google::api::expr::runtime::InterpreterOptions;
using ::testing::HasSubstr;
using ::testing::IsEmpty;
using ::testing::SizeIs;
using ::testing::TestWithParam;
using ::testing::ValuesIn;

using LegacyActivation = google::api::expr::runtime::Activation;

TEST(RegexExtTest, BuildFailsWithoutOptionalSupport) {
RuntimeOptions options;
options.enable_regex = true;
options.enable_qualified_type_identifiers = true;

ASSERT_OK_AND_ASSIGN(auto builder,
CreateStandardRuntimeBuilder(
internal::GetTestingDescriptorPool(), options));
ASSERT_THAT(
EnableReferenceResolver(builder, ReferenceResolverEnabled::kAlways),
IsOk());
// Optional types are NOT enabled.
ASSERT_THAT(RegisterRegexExtensionFunctions(builder),
StatusIs(absl::StatusCode::kInvalidArgument,
HasSubstr("regex extensions requires the optional types "
"to be enabled")));
}

TEST(RegexExtTest, LegacyRuntimeSmokeTest) {
InterpreterOptions options;
options.enable_regex = true;
options.enable_qualified_type_identifiers = true;
options.enable_qualified_identifier_rewrites = true;

std::unique_ptr<CelExpressionBuilder> builder = CreateCelExpressionBuilder(
internal::GetTestingDescriptorPool(), nullptr, options);

// Optional types are NOT enabled.
ASSERT_THAT(RegisterRegexExtensionFunctions(builder->GetRegistry(), options),
IsOk());

ASSERT_OK_AND_ASSIGN(auto expr,
Parse("regex.extractAll('hello world', 'hello (.*)')"));
LegacyActivation activation;
google::protobuf::Arena arena;
ASSERT_OK_AND_ASSIGN(auto program, builder->CreateExpression(
&expr.expr(), &expr.source_info()));
ASSERT_OK_AND_ASSIGN(auto result, program->Evaluate(activation, &arena));
ASSERT_TRUE(result.IsList());
ASSERT_EQ(result.ListOrDie()->size(), 1);
ASSERT_TRUE(result.ListOrDie()->Get(&arena, 0).IsString());
EXPECT_EQ(result.ListOrDie()->Get(&arena, 0).StringOrDie().value(), "world");
}

TEST(RegexExtTest, DoesNotRegisterExtractForLegacy) {
InterpreterOptions options;
options.enable_regex = true;

CelFunctionRegistry registry;
// Optional types are not usable in legacy runtime, so extract should not be
// registered.
ASSERT_THAT(RegisterRegexExtensionFunctions(&registry, options), IsOk());
EXPECT_THAT(
registry.FindStaticOverloads("regex.extract", false,
{cel::Kind::kString, cel::Kind::kString}),
IsEmpty());
EXPECT_THAT(
registry.FindStaticOverloads("regex.extractAll", false,
{cel::Kind::kString, cel::Kind::kString}),
SizeIs(1));
EXPECT_THAT(registry.FindStaticOverloads(
"regex.replace", false,
{cel::Kind::kString, cel::Kind::kString, cel::Kind::kString}),
SizeIs(1));
EXPECT_THAT(
registry.FindStaticOverloads("regex.replace", false,
{cel::Kind::kString, cel::Kind::kString,
cel::Kind::kString, cel::Kind::kInt64}),
SizeIs(1));
}

TEST(RegexExtTest, FollowsRegexOption) {
InterpreterOptions options;
options.enable_regex = false;

CelFunctionRegistry registry;
ASSERT_THAT(RegisterRegexExtensionFunctions(&registry, options), IsOk());
EXPECT_THAT(
registry.FindStaticOverloads("regex.extract", false,
{cel::Kind::kString, cel::Kind::kString}),
IsEmpty());
EXPECT_THAT(
registry.FindStaticOverloads("regex.extractAll", false,
{cel::Kind::kString, cel::Kind::kString}),
IsEmpty());
EXPECT_THAT(registry.FindStaticOverloads(
"regex.replace", false,
{cel::Kind::kString, cel::Kind::kString, cel::Kind::kString}),
IsEmpty());
EXPECT_THAT(
registry.FindStaticOverloads("regex.replace", false,
{cel::Kind::kString, cel::Kind::kString,
cel::Kind::kString, cel::Kind::kInt64}),
IsEmpty());
}

enum class EvaluationType {
kBoolTrue,
kOptionalValue,
Expand Down Expand Up @@ -105,23 +215,6 @@ class RegexExtTest : public TestWithParam<RegexExtTestCase> {
std::unique_ptr<const Runtime> runtime_;
};

TEST_F(RegexExtTest, BuildFailsWithoutOptionalSupport) {
RuntimeOptions options;
options.enable_regex = true;
options.enable_qualified_type_identifiers = true;

ASSERT_OK_AND_ASSIGN(auto builder,
CreateStandardRuntimeBuilder(
internal::GetTestingDescriptorPool(), options));
ASSERT_THAT(
EnableReferenceResolver(builder, ReferenceResolverEnabled::kAlways),
IsOk());
// Optional types are NOT enabled.
ASSERT_THAT(RegisterRegexExtensionFunctions(builder),
StatusIs(absl::StatusCode::kInvalidArgument,
HasSubstr("regex extensions requires the optional types "
"to be enabled")));
}
std::vector<RegexExtTestCase> regexTestCases() {
return {
// Tests for extract Function
Expand Down