Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions checker/internal/type_checker_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1329,8 +1329,9 @@ absl::StatusOr<ValidationResult> TypeCheckerImpl::Check(
// Happens in a second pass to simplify validating that pointers haven't
// been invalidated by other updates.
ResolveRewriter rewriter(visitor, type_inference_context, options_,
ast_impl.reference_map(), ast_impl.type_map());
AstRewrite(ast_impl.root_expr(), rewriter);
ast_impl.mutable_reference_map(),
ast_impl.mutable_type_map());
AstRewrite(ast_impl.mutable_root_expr(), rewriter);

CEL_RETURN_IF_ERROR(rewriter.status());

Expand Down
29 changes: 15 additions & 14 deletions checker/internal/type_checker_impl_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -867,7 +867,8 @@ TEST_P(PrimitiveLiteralsTest, LiteralsTypeInferred) {
ASSERT_TRUE(result.IsValid());
ASSERT_OK_AND_ASSIGN(auto checked_ast, result.ReleaseAst());
auto& ast_impl = AstImpl::CastFromPublicAst(*checked_ast);
EXPECT_EQ(ast_impl.type_map()[1].primitive(), test_case.expected_type);
EXPECT_EQ(ast_impl.mutable_type_map()[1].primitive(),
test_case.expected_type);
}

INSTANTIATE_TEST_SUITE_P(
Expand Down Expand Up @@ -917,7 +918,7 @@ TEST_P(AstTypeConversionTest, TypeConversion) {
ASSERT_TRUE(result.IsValid());
ASSERT_OK_AND_ASSIGN(auto checked_ast, result.ReleaseAst());
auto& ast_impl = AstImpl::CastFromPublicAst(*checked_ast);
EXPECT_EQ(ast_impl.type_map()[1], test_case.expected_type)
EXPECT_EQ(ast_impl.mutable_type_map()[1], test_case.expected_type)
<< GetParam().decl_type.DebugString();
}

Expand Down Expand Up @@ -1041,7 +1042,7 @@ TEST(TypeCheckerImplTest, NullLiteral) {
ASSERT_TRUE(result.IsValid());
ASSERT_OK_AND_ASSIGN(auto checked_ast, result.ReleaseAst());
auto& ast_impl = AstImpl::CastFromPublicAst(*checked_ast);
EXPECT_TRUE(ast_impl.type_map()[1].has_null());
EXPECT_TRUE(ast_impl.mutable_type_map()[1].has_null());
}

TEST(TypeCheckerImplTest, ExpressionLimitInclusive) {
Expand Down Expand Up @@ -1114,7 +1115,7 @@ TEST(TypeCheckerImplTest, BasicOvlResolution) {
// Assumes parser numbering: + should always be id 2.
ASSERT_OK_AND_ASSIGN(auto checked_ast, result.ReleaseAst());
auto& ast_impl = AstImpl::CastFromPublicAst(*checked_ast);
EXPECT_THAT(ast_impl.reference_map()[2],
EXPECT_THAT(ast_impl.mutable_reference_map()[2],
IsFunctionReference(
"_+_", std::vector<std::string>{"add_double_double"}));
}
Expand All @@ -1138,7 +1139,7 @@ TEST(TypeCheckerImplTest, OvlResolutionMultipleOverloads) {
// Assumes parser numbering: + should always be id 3.
ASSERT_OK_AND_ASSIGN(auto checked_ast, result.ReleaseAst());
auto& ast_impl = AstImpl::CastFromPublicAst(*checked_ast);
EXPECT_THAT(ast_impl.reference_map()[3],
EXPECT_THAT(ast_impl.mutable_reference_map()[3],
IsFunctionReference("_+_", std::vector<std::string>{
"add_double_double", "add_int_int",
"add_list", "add_uint_uint"}));
Expand All @@ -1164,14 +1165,14 @@ TEST(TypeCheckerImplTest, BasicFunctionResultTypeResolution) {
// Assumes parser numbering: + should always be id 2 and 4.
ASSERT_OK_AND_ASSIGN(auto checked_ast, result.ReleaseAst());
auto& ast_impl = AstImpl::CastFromPublicAst(*checked_ast);
EXPECT_THAT(ast_impl.reference_map()[2],
EXPECT_THAT(ast_impl.mutable_reference_map()[2],
IsFunctionReference(
"_+_", std::vector<std::string>{"add_double_double"}));
EXPECT_THAT(ast_impl.reference_map()[4],
EXPECT_THAT(ast_impl.mutable_reference_map()[4],
IsFunctionReference(
"_+_", std::vector<std::string>{"add_double_double"}));
int64_t root_id = ast_impl.root_expr().id();
EXPECT_EQ(ast_impl.type_map()[root_id].primitive(),
EXPECT_EQ(ast_impl.mutable_type_map()[root_id].primitive(),
ast_internal::PrimitiveType::kDouble);
}

Expand Down Expand Up @@ -1335,7 +1336,7 @@ TEST(TypeCheckerImplTest, BadSourcePosition) {
TypeCheckerImpl impl(std::move(env));
ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("foo"));
auto& ast_impl = AstImpl::CastFromPublicAst(*ast);
ast_impl.source_info().mutable_positions()[1] = -42;
ast_impl.mutable_source_info().mutable_positions()[1] = -42;
ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast)));
ASSERT_OK_AND_ASSIGN(auto source, NewSource("foo"));

Expand Down Expand Up @@ -1365,7 +1366,7 @@ TEST(TypeCheckerImplTest, FailsIfNoTypeDeduced) {
// Assume that an unspecified expr kind is not deducible.
Expr unspecified_expr;
unspecified_expr.set_id(3);
ast_impl.root_expr().mutable_call_expr().mutable_args()[1] =
ast_impl.mutable_root_expr().mutable_call_expr().mutable_args()[1] =
std::move(unspecified_expr);

ASSERT_THAT(impl.Check(std::move(ast)),
Expand All @@ -1382,7 +1383,7 @@ TEST(TypeCheckerImplTest, BadLineOffsets) {
{
ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("\nfoo"));
auto& ast_impl = AstImpl::CastFromPublicAst(*ast);
ast_impl.source_info().mutable_line_offsets()[1] = 1;
ast_impl.mutable_source_info().mutable_line_offsets()[1] = 1;
ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast)));

EXPECT_FALSE(result.IsValid());
Expand All @@ -1395,9 +1396,9 @@ TEST(TypeCheckerImplTest, BadLineOffsets) {
{
ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("\nfoo"));
auto& ast_impl = AstImpl::CastFromPublicAst(*ast);
ast_impl.source_info().mutable_line_offsets().clear();
ast_impl.source_info().mutable_line_offsets().push_back(-1);
ast_impl.source_info().mutable_line_offsets().push_back(2);
ast_impl.mutable_source_info().mutable_line_offsets().clear();
ast_impl.mutable_source_info().mutable_line_offsets().push_back(-1);
ast_impl.mutable_source_info().mutable_line_offsets().push_back(2);

ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast)));

Expand Down
108 changes: 54 additions & 54 deletions checker/optional_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -49,10 +49,8 @@ using ::testing::Not;
using ::testing::Property;
using ::testing::SizeIs;

using AstType = ast_internal::Type;

MATCHER_P(IsOptionalType, inner_type, "") {
const ast_internal::Type& type = arg;
const TypeSpec& type = arg;
if (!type.has_abstract_type()) {
return false;
}
Expand Down Expand Up @@ -100,13 +98,13 @@ TEST(OptionalTest, OptSelectDoesNotAnnotateFieldType) {
EXPECT_NE(field_id, 0);

EXPECT_THAT(ast_impl.type_map(), Not(Contains(Key(field_id))));
EXPECT_THAT(ast_impl.GetType(ast_impl.root_expr().id()),
IsOptionalType(AstType(ast_internal::PrimitiveType::kInt64)));
EXPECT_THAT(ast_impl.GetTypeOrDyn(ast_impl.root_expr().id()),
IsOptionalType(TypeSpec(ast_internal::PrimitiveType::kInt64)));
}

struct TestCase {
std::string expr;
testing::Matcher<ast_internal::Type> result_type_matcher;
testing::Matcher<TypeSpec> result_type_matcher;
std::string error_substring;
};

Expand Down Expand Up @@ -144,7 +142,7 @@ TEST_P(OptionalTest, Runner) {

int64_t root_id = ast_impl.root_expr().id();

EXPECT_THAT(ast_impl.GetType(root_id), test_case.result_type_matcher)
EXPECT_THAT(ast_impl.GetTypeOrDyn(root_id), test_case.result_type_matcher)
<< "for expression: " << test_case.expr;
}

Expand All @@ -153,130 +151,132 @@ INSTANTIATE_TEST_SUITE_P(
::testing::Values(
TestCase{
"optional.of('abc')",
IsOptionalType(AstType(ast_internal::PrimitiveType::kString)),
IsOptionalType(TypeSpec(ast_internal::PrimitiveType::kString)),
},
TestCase{
"optional.ofNonZeroValue('')",
IsOptionalType(AstType(ast_internal::PrimitiveType::kString)),
IsOptionalType(TypeSpec(ast_internal::PrimitiveType::kString)),
},
TestCase{
"optional.none()",
IsOptionalType(AstType(ast_internal::DynamicType())),
IsOptionalType(TypeSpec(ast_internal::DynamicType())),
},
TestCase{
"optional.of('abc').hasValue()",
Eq(AstType(ast_internal::PrimitiveType::kBool)),
Eq(TypeSpec(ast_internal::PrimitiveType::kBool)),
},
TestCase{
"optional.of('abc').value()",
Eq(AstType(ast_internal::PrimitiveType::kString)),
Eq(TypeSpec(ast_internal::PrimitiveType::kString)),
},
TestCase{
"type(optional.of('abc')) == optional_type",
Eq(AstType(ast_internal::PrimitiveType::kBool)),
Eq(TypeSpec(ast_internal::PrimitiveType::kBool)),
},
TestCase{
"type(optional.of('abc')) == optional_type",
Eq(AstType(ast_internal::PrimitiveType::kBool)),
Eq(TypeSpec(ast_internal::PrimitiveType::kBool)),
},
TestCase{
"optional.of('abc').or(optional.of('def'))",
IsOptionalType(AstType(ast_internal::PrimitiveType::kString)),
IsOptionalType(TypeSpec(ast_internal::PrimitiveType::kString)),
},
TestCase{"optional.of('abc').or(optional.of(1))", _,
"no matching overload for 'or'"},
TestCase{
"optional.of('abc').orValue('def')",
Eq(AstType(ast_internal::PrimitiveType::kString)),
Eq(TypeSpec(ast_internal::PrimitiveType::kString)),
},
TestCase{"optional.of('abc').orValue(1)", _,
"no matching overload for 'orValue'"},
TestCase{
"{'k': 'v'}.?k",
IsOptionalType(AstType(ast_internal::PrimitiveType::kString)),
IsOptionalType(TypeSpec(ast_internal::PrimitiveType::kString)),
},
TestCase{"1.?k", _,
"expression of type 'int' cannot be the operand of a select "
"operation"},
TestCase{
"{'k': {'k': 'v'}}.?k.?k2",
IsOptionalType(AstType(ast_internal::PrimitiveType::kString)),
IsOptionalType(TypeSpec(ast_internal::PrimitiveType::kString)),
},
TestCase{
"{'k': {'k': 'v'}}.?k.k2",
IsOptionalType(AstType(ast_internal::PrimitiveType::kString)),
IsOptionalType(TypeSpec(ast_internal::PrimitiveType::kString)),
},
TestCase{"{?'k': optional.of('v')}",
Eq(AstType(ast_internal::MapType(
std::unique_ptr<AstType>(
new AstType(ast_internal::PrimitiveType::kString)),
std::unique_ptr<AstType>(
new AstType(ast_internal::PrimitiveType::kString)))))},
Eq(TypeSpec(ast_internal::MapType(
std::unique_ptr<TypeSpec>(
new TypeSpec(ast_internal::PrimitiveType::kString)),
std::unique_ptr<TypeSpec>(new TypeSpec(
ast_internal::PrimitiveType::kString)))))},
TestCase{"{'k': 'v', ?'k2': optional.none()}",
Eq(AstType(ast_internal::MapType(
std::unique_ptr<AstType>(
new AstType(ast_internal::PrimitiveType::kString)),
std::unique_ptr<AstType>(
new AstType(ast_internal::PrimitiveType::kString)))))},
Eq(TypeSpec(ast_internal::MapType(
std::unique_ptr<TypeSpec>(
new TypeSpec(ast_internal::PrimitiveType::kString)),
std::unique_ptr<TypeSpec>(new TypeSpec(
ast_internal::PrimitiveType::kString)))))},
TestCase{"{'k': 'v', ?'k2': 'v'}", _,
"expected type 'optional_type(string)' but found 'string'"},
TestCase{"[?optional.of('v')]",
Eq(AstType(ast_internal::ListType(std::unique_ptr<AstType>(
new AstType(ast_internal::PrimitiveType::kString)))))},
Eq(TypeSpec(ast_internal::ListType(std::unique_ptr<TypeSpec>(
new TypeSpec(ast_internal::PrimitiveType::kString)))))},
TestCase{"['v', ?optional.none()]",
Eq(AstType(ast_internal::ListType(std::unique_ptr<AstType>(
new AstType(ast_internal::PrimitiveType::kString)))))},
Eq(TypeSpec(ast_internal::ListType(std::unique_ptr<TypeSpec>(
new TypeSpec(ast_internal::PrimitiveType::kString)))))},
TestCase{"['v1', ?'v2']", _,
"expected type 'optional_type(string)' but found 'string'"},
TestCase{"[optional.of(dyn('1')), optional.of('2')][0]",
IsOptionalType(AstType(ast_internal::DynamicType()))},
IsOptionalType(TypeSpec(ast_internal::DynamicType()))},
TestCase{"[optional.of('1'), optional.of(dyn('2'))][0]",
IsOptionalType(AstType(ast_internal::DynamicType()))},
IsOptionalType(TypeSpec(ast_internal::DynamicType()))},
TestCase{"[{1: optional.of(1)}, {1: optional.of(dyn(1))}][0][1]",
IsOptionalType(AstType(ast_internal::DynamicType()))},
IsOptionalType(TypeSpec(ast_internal::DynamicType()))},
TestCase{"[{1: optional.of(dyn(1))}, {1: optional.of(1)}][0][1]",
IsOptionalType(AstType(ast_internal::DynamicType()))},
IsOptionalType(TypeSpec(ast_internal::DynamicType()))},
TestCase{"[optional.of('1'), optional.of(2)][0]",
Eq(AstType(ast_internal::DynamicType()))},
Eq(TypeSpec(ast_internal::DynamicType()))},
TestCase{"['v1', ?'v2']", _,
"expected type 'optional_type(string)' but found 'string'"},
TestCase{"cel.expr.conformance.proto3.TestAllTypes{?single_int64: "
"optional.of(1)}",
Eq(AstType(ast_internal::MessageType(
Eq(TypeSpec(ast_internal::MessageType(
"cel.expr.conformance.proto3.TestAllTypes")))},
TestCase{"[0][?1]",
IsOptionalType(AstType(ast_internal::PrimitiveType::kInt64))},
IsOptionalType(TypeSpec(ast_internal::PrimitiveType::kInt64))},
TestCase{"[[0]][?1][?1]",
IsOptionalType(AstType(ast_internal::PrimitiveType::kInt64))},
IsOptionalType(TypeSpec(ast_internal::PrimitiveType::kInt64))},
TestCase{"[[0]][?1][1]",
IsOptionalType(AstType(ast_internal::PrimitiveType::kInt64))},
IsOptionalType(TypeSpec(ast_internal::PrimitiveType::kInt64))},
TestCase{"{0: 1}[?1]",
IsOptionalType(AstType(ast_internal::PrimitiveType::kInt64))},
IsOptionalType(TypeSpec(ast_internal::PrimitiveType::kInt64))},
TestCase{"{0: {0: 1}}[?1][?1]",
IsOptionalType(AstType(ast_internal::PrimitiveType::kInt64))},
IsOptionalType(TypeSpec(ast_internal::PrimitiveType::kInt64))},
TestCase{"{0: {0: 1}}[?1][1]",
IsOptionalType(AstType(ast_internal::PrimitiveType::kInt64))},
IsOptionalType(TypeSpec(ast_internal::PrimitiveType::kInt64))},
TestCase{"{0: {0: 1}}[?1]['']", _, "no matching overload for '_[_]'"},
TestCase{"{0: {0: 1}}[?1][?'']", _, "no matching overload for '_[?_]'"},
TestCase{"optional.of('abc').optMap(x, x + 'def')",
IsOptionalType(AstType(ast_internal::PrimitiveType::kString))},
TestCase{"optional.of('abc').optFlatMap(x, optional.of(x + 'def'))",
IsOptionalType(AstType(ast_internal::PrimitiveType::kString))},
TestCase{
"optional.of('abc').optMap(x, x + 'def')",
IsOptionalType(TypeSpec(ast_internal::PrimitiveType::kString))},
TestCase{
"optional.of('abc').optFlatMap(x, optional.of(x + 'def'))",
IsOptionalType(TypeSpec(ast_internal::PrimitiveType::kString))},
// Legacy nullability behaviors.
TestCase{"cel.expr.conformance.proto3.TestAllTypes{?null_value: "
"optional.of(0)}",
Eq(AstType(ast_internal::MessageType(
Eq(TypeSpec(ast_internal::MessageType(
"cel.expr.conformance.proto3.TestAllTypes")))},
TestCase{"cel.expr.conformance.proto3.TestAllTypes{?null_value: null}",
Eq(AstType(ast_internal::MessageType(
Eq(TypeSpec(ast_internal::MessageType(
"cel.expr.conformance.proto3.TestAllTypes")))},
TestCase{"cel.expr.conformance.proto3.TestAllTypes{?null_value: "
"optional.of(null)}",
Eq(AstType(ast_internal::MessageType(
Eq(TypeSpec(ast_internal::MessageType(
"cel.expr.conformance.proto3.TestAllTypes")))},
TestCase{"cel.expr.conformance.proto3.TestAllTypes{}.?single_int64 "
"== null",
Eq(AstType(ast_internal::PrimitiveType::kBool))}));
Eq(TypeSpec(ast_internal::PrimitiveType::kBool))}));

class OptionalStrictNullAssignmentTest
: public testing::TestWithParam<TestCase> {};
Expand Down Expand Up @@ -315,7 +315,7 @@ TEST_P(OptionalStrictNullAssignmentTest, Runner) {

int64_t root_id = ast_impl.root_expr().id();

EXPECT_THAT(ast_impl.GetType(root_id), test_case.result_type_matcher)
EXPECT_THAT(ast_impl.GetTypeOrDyn(root_id), test_case.result_type_matcher)
<< "for expression: " << test_case.expr;
}

Expand Down
15 changes: 8 additions & 7 deletions checker/standard_library_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,8 @@ TEST(StandardLibraryTest, ComprehensionResultTypeIsSubstituted) {
const ast_internal::AstImpl& checked_impl =
ast_internal::AstImpl::CastFromPublicAst(*checked_ast);

ast_internal::Type type = checked_impl.GetType(checked_impl.root_expr().id());
ast_internal::Type type =
checked_impl.GetTypeOrDyn(checked_impl.root_expr().id());
EXPECT_TRUE(type.has_primitive() &&
type.primitive() == ast_internal::PrimitiveType::kInt64);
}
Expand All @@ -160,8 +161,8 @@ class StdlibTypeVarDefinitionTest

TEST_P(StdlibTypeVarDefinitionTest, DefinesTypeConstants) {
auto ast = std::make_unique<AstImpl>();
ast->root_expr().mutable_ident_expr().set_name(GetParam());
ast->root_expr().set_id(1);
ast->mutable_root_expr().mutable_ident_expr().set_name(GetParam());
ast->mutable_root_expr().set_id(1);

ASSERT_OK_AND_ASSIGN(ValidationResult result,
stdlib_type_checker_->Check(std::move(ast)));
Expand All @@ -171,7 +172,7 @@ TEST_P(StdlibTypeVarDefinitionTest, DefinesTypeConstants) {
const auto& checked_impl = AstImpl::CastFromPublicAst(*checked_ast);
EXPECT_THAT(checked_impl.GetReference(1),
Pointee(Property(&Reference::name, GetParam())));
EXPECT_THAT(checked_impl.GetType(1), Property(&AstType::has_type, true));
EXPECT_THAT(checked_impl.GetTypeOrDyn(1), Property(&AstType::has_type, true));
}

INSTANTIATE_TEST_SUITE_P(StdlibTypeVarDefinitions, StdlibTypeVarDefinitionTest,
Expand All @@ -185,7 +186,7 @@ INSTANTIATE_TEST_SUITE_P(StdlibTypeVarDefinitions, StdlibTypeVarDefinitionTest,
TEST_F(StandardLibraryDefinitionsTest, DefinesProtoStructNull) {
auto ast = std::make_unique<AstImpl>();

auto& enumerator = ast->root_expr();
auto& enumerator = ast->mutable_root_expr();
enumerator.set_id(4);
enumerator.mutable_select_expr().set_field("NULL_VALUE");
auto& enumeration = enumerator.mutable_select_expr().mutable_operand();
Expand All @@ -212,7 +213,7 @@ TEST_F(StandardLibraryDefinitionsTest, DefinesProtoStructNull) {
TEST_F(StandardLibraryDefinitionsTest, DefinesTypeType) {
auto ast = std::make_unique<AstImpl>();

auto& ident = ast->root_expr();
auto& ident = ast->mutable_root_expr();
ident.set_id(1);
ident.mutable_ident_expr().set_name("type");

Expand All @@ -224,7 +225,7 @@ TEST_F(StandardLibraryDefinitionsTest, DefinesTypeType) {
const auto& checked_impl = AstImpl::CastFromPublicAst(*checked_ast);
EXPECT_THAT(checked_impl.GetReference(1),
Pointee(Property(&Reference::name, "type")));
EXPECT_THAT(checked_impl.GetType(1), Property(&AstType::has_type, true));
EXPECT_THAT(checked_impl.GetTypeOrDyn(1), Property(&AstType::has_type, true));
}

struct DefinitionsTestCase {
Expand Down
Loading