Skip to content

Commit

Permalink
Add matchers for FunctionBase+Function+Proc.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 541942492
  • Loading branch information
grebe authored and copybara-github committed Jun 20, 2023
1 parent 7d25edd commit fbfe510
Show file tree
Hide file tree
Showing 4 changed files with 224 additions and 1 deletion.
1 change: 1 addition & 0 deletions xls/ir/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -1079,6 +1079,7 @@ cc_library(
":type",
":value",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/types:optional",
"@com_google_googletest//:gtest",
],
Expand Down
54 changes: 54 additions & 0 deletions xls/ir/ir_matcher.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,15 @@

#include "xls/ir/ir_matcher.h"

#include <ostream>
#include <string>
#include <vector>

#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
#include "xls/ir/function_base.h"
#include "xls/ir/nodes.h"
#include "xls/ir/type.h"

Expand Down Expand Up @@ -469,6 +476,53 @@ void RegisterMatcher::DescribeTo(::std::ostream* os) const {
}
*os << ")";
}
bool FunctionBaseMatcher::MatchAndExplain(
const ::xls::FunctionBase* fb,
::testing::MatchResultListener* listener) const {
if (fb == nullptr) {
return false;
}
*listener << fb->name();
if (name_.has_value() && fb->name() != *name_) {
*listener << absl::StreamFormat(" has incorrect name %s, expected: %s",
fb->name(), *name_);
return false;
}
return true;
}

void FunctionBaseMatcher::DescribeTo(::std::ostream* os) const {
if (name_.has_value()) {
*os << absl::StreamFormat("FunctionBase(name=%s)", *name_);
} else {
*os << "FunctionBase()";
}
}

void ProcMatcher::DescribeTo(::std::ostream* os) const {
*os << absl::StreamFormat("proc %s { ... }", name_.value_or("<unspecified>"));
}

void ProcMatcher::DescribeNegationTo(std::ostream* os) const {
if (name_.has_value()) {
*os << absl::StreamFormat("FunctionBase was not a proc named %s.", *name_);
} else {
*os << "FunctionBase was not a proc.";
}
}

void FunctionMatcher::DescribeTo(::std::ostream* os) const {
*os << absl::StreamFormat("fn %s { ... }", name_.value_or("<unspecified>"));
}

void FunctionMatcher::DescribeNegationTo(std::ostream* os) const {
if (name_.has_value()) {
*os << absl::StreamFormat("FunctionBase was not a function named %s.",
*name_);
} else {
*os << "FunctionBase was not a function.";
}
}

} // namespace op_matchers
} // namespace xls
128 changes: 128 additions & 0 deletions xls/ir/ir_matcher.h
Original file line number Diff line number Diff line change
Expand Up @@ -1015,6 +1015,134 @@ inline ::testing::Matcher<const ::xls::Node*> Register(
new ::xls::op_matchers::RegisterMatcher(std::move(input), register_name));
}

// Matcher for FunctionBase. Supported form:
//
// m::FunctionBase(/*name=*/"foo");
//
class FunctionBaseMatcher
: public ::testing::MatcherInterface<const ::xls::FunctionBase*> {
public:
explicit FunctionBaseMatcher(std::optional<std::string> name)
: name_(std::move(name)) {}

bool MatchAndExplain(const ::xls::FunctionBase* fb,
::testing::MatchResultListener* listener) const override;

void DescribeTo(::std::ostream* os) const override;

protected:
std::optional<std::string> name_;
};

inline ::testing::Matcher<const ::xls::FunctionBase*> FunctionBase(
std::optional<std::string> name = std::nullopt) {
return ::testing::MakeMatcher(
new ::xls::op_matchers::FunctionBaseMatcher(std::move(name)));
}

// Matcher for functions. Supported forms:
//
// m::Function(/*name=*/"foo");
// m::Function();
//
class FunctionMatcher {
public:
using is_gtest_matcher = void;

explicit FunctionMatcher(std::optional<std::string> name)
: name_(std::move(name)) {}

template <typename T, typename = absl::enable_if_t<
std::is_convertible_v<T*, ::xls::FunctionBase*>>>
bool MatchAndExplain(const T* fb,
::testing::MatchResultListener* listener) const {
if (fb == nullptr) {
return false;
}
*listener << fb->name();
if (!fb->IsFunction()) {
*listener << " is not a function.";
return false;
}
// Now, match on FunctionBase.
if (!FunctionBase(name_).MatchAndExplain(fb, listener)) {
return false;
}

return true;
}

template <typename T, typename = absl::enable_if_t<
std::is_convertible_v<T*, ::xls::FunctionBase*>>>
bool MatchAndExplain(const std::unique_ptr<T>& fb,
::testing::MatchResultListener* listener) const {
return MatchAndExplain(fb.get(), listener);
}

void DescribeTo(::std::ostream* os) const;
void DescribeNegationTo(std::ostream* os) const;

protected:
std::optional<std::string> name_;
};

inline ::testing::PolymorphicMatcher<FunctionMatcher> Function(
std::optional<std::string> name = std::nullopt) {
return testing::MakePolymorphicMatcher(
::xls::op_matchers::FunctionMatcher(std::move(name)));
}

// Matcher for procs. Supported forms:
//
// m::Proc(/*name=*/"foo");
// m::Proc();
//
class ProcMatcher {
public:
using is_gtest_matcher = void;

explicit ProcMatcher(std::optional<std::string> name)
: name_(std::move(name)) {}

template <typename T, typename = absl::enable_if_t<
std::is_convertible_v<T*, ::xls::FunctionBase*>>>
bool MatchAndExplain(const T* fb,
::testing::MatchResultListener* listener) const {
if (fb == nullptr) {
return false;
}
*listener << fb->name();
if (!fb->IsProc()) {
*listener << " is not a proc.";
return false;
}
// Now, match on FunctionBase.
if (!FunctionBase(name_).MatchAndExplain(fb, listener)) {
return false;
}

return true;
}

template <typename T, typename = absl::enable_if_t<
std::is_convertible_v<T*, ::xls::FunctionBase*>>>
bool MatchAndExplain(const std::unique_ptr<T>& fb,
::testing::MatchResultListener* listener) const {
return MatchAndExplain(fb.get(), listener);
}

void DescribeTo(::std::ostream* os) const;
void DescribeNegationTo(std::ostream* os) const;

protected:
std::optional<std::string> name_;
};

inline ::testing::PolymorphicMatcher<ProcMatcher> Proc(
std::optional<std::string> name = std::nullopt) {
return ::testing::MakePolymorphicMatcher(
::xls::op_matchers::ProcMatcher(std::move(name)));
}
} // namespace op_matchers
} // namespace xls

Expand Down
42 changes: 41 additions & 1 deletion xls/ir/ir_matcher_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,9 @@ namespace {

using ::testing::_;
using ::testing::AllOf;
using ::testing::Eq;
using ::testing::Contains;
using ::testing::HasSubstr;
using ::testing::UnorderedElementsAre;

template <typename M, typename T>
std::string Explain(const T& t, const M& m) {
Expand Down Expand Up @@ -478,5 +479,44 @@ TEST(IrMatchersTest, RegisterMatcher) {
HasSubstr("has incorrect register (reg), expected: wrong-reg"));
}

TEST(IrMatchersTest, FunctionBaseMatcher) {
Package p("p");
FunctionBuilder fb("f", &p);
auto x = fb.Param("x", p.GetBitsType(32));
auto y = fb.Param("y", p.GetBitsType(32));
XLS_ASSERT_OK_AND_ASSIGN(Function * f, fb.BuildWithReturnValue(fb.Add(x, y)));

XLS_ASSERT_OK_AND_ASSIGN(
StreamingChannel * ch0,
p.CreateStreamingChannel("ch0", ChannelOps::kReceiveOnly,
p.GetBitsType(32)));
XLS_ASSERT_OK_AND_ASSIGN(
StreamingChannel * ch1,
p.CreateStreamingChannel("ch1", ChannelOps::kSendOnly,
p.GetBitsType(32)));
ProcBuilder pb("test_proc", "tok", &p);
BValue rcv = pb.Receive(ch0, pb.GetTokenParam());
BValue rcv_token = pb.TupleIndex(rcv, 0);
BValue rcv_data = pb.TupleIndex(rcv, 1);
BValue f_of_data = pb.Invoke({rcv_data, rcv_data}, f);
BValue send_token = pb.Send(ch1, rcv_token, f_of_data);
XLS_ASSERT_OK(pb.Build(send_token, {}).status());

// Match FunctionBases.
EXPECT_THAT(
p.GetFunctionBases(),
UnorderedElementsAre(m::FunctionBase("f"), m::FunctionBase("test_proc")));
EXPECT_THAT(p.GetFunctionBases(), Not(Contains(m::FunctionBase("foobar"))));

// Match Function and Proc.
EXPECT_THAT(p.GetFunctionBases(),
UnorderedElementsAre(m::Function("f"), m::Proc("test_proc")));
EXPECT_THAT(p.GetFunctionBases(), Not(Contains(m::Function("test_proc"))));
EXPECT_THAT(p.GetFunctionBases(), Not(Contains(m::Proc("f"))));

EXPECT_THAT(p.procs(), UnorderedElementsAre(m::Proc("test_proc")));
EXPECT_THAT(p.functions(), UnorderedElementsAre(m::Function("f")));
}

} // namespace
} // namespace xls

0 comments on commit fbfe510

Please sign in to comment.