Skip to content

Commit

Permalink
[DSLX:fmt] First attempt at conditional structure formatting.
Browse files Browse the repository at this point in the history
Only ternary-style are permitted to be inline -- multiple statements in a block
or an else-if arm forces it to turn to multi-line form.

PiperOrigin-RevId: 572411106
  • Loading branch information
cdleary authored and copybara-github committed Oct 11, 2023
1 parent a803fb2 commit 8e4d778
Show file tree
Hide file tree
Showing 6 changed files with 228 additions and 32 deletions.
180 changes: 150 additions & 30 deletions xls/dslx/fmt/ast_fmt.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include <functional>
#include <optional>
#include <utility>
#include <variant>
#include <vector>

#include "absl/container/flat_hash_map.h"
Expand Down Expand Up @@ -245,31 +246,47 @@ DocRef Fmt(const Binop& n, const Comments& comments, DocArena& arena) {
return ConcatNGroup(arena, pieces);
}

DocRef Fmt(const Block& n, const Comments& comments, DocArena& arena) {
// Note: we only add leading/trailing spaces in the block if add_curls is true.
static DocRef FmtBlock(const Block& n, const Comments& comments,
DocArena& arena, bool add_curls) {
if (n.statements().empty()) {
return ConcatNGroup(arena, {arena.ocurl(), arena.break0(), arena.ccurl()});
if (add_curls) {
return ConcatNGroup(arena,
{arena.ocurl(), arena.break0(), arena.ccurl()});
}
return arena.break0();
}

// We only want to flatten single-statement blocks -- multi-statement blocks
// we always make line breaks between the statements.
if (n.statements().size() == 1) {
std::vector<DocRef> pieces = {arena.ocurl(), arena.break1(),
Fmt(*n.statements()[0], comments, arena)};
std::vector<DocRef> pieces;
if (add_curls) {
pieces = {arena.ocurl(), arena.break1()};
}

pieces.push_back(Fmt(*n.statements()[0], comments, arena));

if (n.trailing_semi()) {
pieces.push_back(arena.semi());
}
pieces.push_back(arena.break1());
pieces.push_back(arena.ccurl());
return ConcatNGroup(arena, pieces);
if (add_curls) {
pieces.push_back(arena.break1());
pieces.push_back(arena.ccurl());
}
return arena.MakeNest(ConcatNGroup(arena, pieces));
}

// Emit a '{' then nest to emit statements with semis, then emit a '}' outside
// the nesting.
std::vector<DocRef> top = {
arena.ocurl(),
};
std::vector<DocRef> top;

if (add_curls) {
top.push_back(arena.ocurl());
top.push_back(arena.hard_line());
}

std::vector<DocRef> nested = {arena.hard_line()};
std::vector<DocRef> nested;
for (size_t i = 0; i < n.statements().size(); ++i) {
const Statement* stmt = n.statements()[i];
nested.push_back(Fmt(*stmt, comments, arena));
Expand All @@ -283,12 +300,18 @@ DocRef Fmt(const Block& n, const Comments& comments, DocArena& arena) {
}

top.push_back(arena.MakeNest(ConcatN(arena, nested)));
top.push_back(arena.hard_line());
top.push_back(arena.ccurl());
if (add_curls) {
top.push_back(arena.hard_line());
top.push_back(arena.ccurl());
}

return ConcatNGroup(arena, top);
}

DocRef Fmt(const Block& n, const Comments& comments, DocArena& arena) {
return FmtBlock(n, comments, arena, /*add_curls=*/true);
}

DocRef Fmt(const Cast& n, const Comments& comments, DocArena& arena) {
DocRef lhs = Fmt(*n.expr(), comments, arena);

Expand Down Expand Up @@ -448,8 +471,85 @@ DocRef Fmt(const String& n, const Comments& comments, DocArena& arena) {
return arena.MakeText(n.ToString());
}

// Creates a group that has the "test portion" of the conditional; i.e.
//
// if <break1> $test_expr <break1> {
static DocRef MakeConditionalTestGroup(const Conditional& n,
const Comments& comments,
DocArena& arena) {
return ConcatNGroup(arena, {
arena.Make(Keyword::kIf),
arena.break1(),
Fmt(*n.test(), comments, arena),
arena.break1(),
arena.ocurl(),
});
}

// When there's an else-if, or multiple statements inside of the blocks, we
// force the formatting to be multi-line.
static DocRef FmtConditionalMultiline(const Conditional& n,
const Comments& comments,
DocArena& arena) {
std::vector<DocRef> pieces = {
MakeConditionalTestGroup(n, comments, arena), arena.hard_line(),
FmtBlock(*n.consequent(), comments, arena, /*add_curls=*/false),
arena.hard_line()};

std::variant<Block*, Conditional*> alternate = n.alternate();
while (std::holds_alternative<Conditional*>(alternate)) {
Conditional* elseif = std::get<Conditional*>(alternate);
alternate = elseif->alternate();
pieces.push_back(arena.ccurl());
pieces.push_back(arena.space());
pieces.push_back(arena.Make(Keyword::kElse));
pieces.push_back(arena.space());
pieces.push_back(MakeConditionalTestGroup(*elseif, comments, arena));
pieces.push_back(arena.hard_line());
pieces.push_back(
FmtBlock(*elseif->consequent(), comments, arena, /*add_curls=*/false));
pieces.push_back(arena.hard_line());
}

XLS_CHECK(std::holds_alternative<Block*>(alternate));

Block* else_block = std::get<Block*>(alternate);
pieces.push_back(arena.ccurl());
pieces.push_back(arena.space());
pieces.push_back(arena.Make(Keyword::kElse));
pieces.push_back(arena.space());
pieces.push_back(arena.ocurl());
pieces.push_back(arena.hard_line());
pieces.push_back(FmtBlock(*else_block, comments, arena, /*add_curls=*/false));
pieces.push_back(arena.hard_line());
pieces.push_back(arena.ccurl());

return ConcatN(arena, pieces);
}

DocRef Fmt(const Conditional& n, const Comments& comments, DocArena& arena) {
XLS_LOG(FATAL) << "handle conditional: " << n.ToString();
// If there's an else-if clause or multi-statement blocks we force it to be
// multi-line.
if (n.HasElseIf() || n.HasMultiStatementBlocks()) {
return FmtConditionalMultiline(n, comments, arena);
}

std::vector<DocRef> pieces = {
MakeConditionalTestGroup(n, comments, arena),
arena.break1(),
FmtBlock(*n.consequent(), comments, arena, /*add_curls=*/false),
arena.break1(),
};

XLS_CHECK(std::holds_alternative<Block*>(n.alternate()));
const Block* else_block = std::get<Block*>(n.alternate());
pieces.push_back(ConcatNGroup(
arena, {arena.ccurl(), arena.break1(), arena.Make(Keyword::kElse),
arena.break1(), arena.ocurl(), arena.break1()}));
pieces.push_back(FmtBlock(*else_block, comments, arena, /*add_curls=*/false));
pieces.push_back(arena.break1());
pieces.push_back(arena.ccurl());
return ConcatNGroup(arena, pieces);
}

DocRef Fmt(const ConstAssert& n, const Comments& comments, DocArena& arena) {
Expand Down Expand Up @@ -694,30 +794,50 @@ DocRef Fmt(const Function& n, const Comments& comments, DocArena& arena) {

DocRef params = FmtParams(n.params(), comments, arena);

std::vector<DocRef> pieces = {fn, arena.break1(), name};
std::vector<DocRef> signature_pieces = {fn, arena.break1(), name};

if (n.IsParametric()) {
pieces.push_back(arena.oangle());
pieces.push_back(FmtJoin<const ParametricBinding*>(
n.parametric_bindings(), FmtParametricBindingPtr, comments, arena));
pieces.push_back(arena.cangle());
signature_pieces.push_back(ConcatNGroup(
arena,
{arena.oangle(),
FmtJoin<const ParametricBinding*>(
n.parametric_bindings(), FmtParametricBindingPtr, comments, arena),
arena.cangle()}));
}

pieces.push_back(arena.break0());
pieces.push_back(params);
pieces.push_back(arena.break1());
signature_pieces.push_back(arena.break0());
signature_pieces.push_back(params);
signature_pieces.push_back(arena.break1());

if (n.return_type() != nullptr) {
pieces.push_back(arena.arrow());
pieces.push_back(arena.break1());
pieces.push_back(Fmt(*n.return_type(), comments, arena));
pieces.push_back(arena.break1());
signature_pieces.push_back(arena.arrow());
signature_pieces.push_back(arena.break1());
signature_pieces.push_back(Fmt(*n.return_type(), comments, arena));
signature_pieces.push_back(arena.break1());
}

return ConcatNGroup(arena, {
ConcatNGroup(arena, pieces),
Fmt(*n.body(), comments, arena),
});
signature_pieces.push_back(arena.ocurl());

// For empty function we don't put spaces between the curls.
if (n.body()->empty()) {
std::vector<DocRef> fn_pieces = {
ConcatNGroup(arena, signature_pieces),
FmtBlock(*n.body(), comments, arena, /*add_curls=*/false),
arena.ccurl(),
};

return ConcatNGroup(arena, fn_pieces);
}

std::vector<DocRef> fn_pieces = {
ConcatNGroup(arena, signature_pieces),
arena.break1(),
FmtBlock(*n.body(), comments, arena, /*add_curls=*/false),
arena.break1(),
arena.ccurl(),
};

return ConcatNGroup(arena, fn_pieces);
}

static DocRef Fmt(const Proc& n, const Comments& comments, DocArena& arena) {
Expand Down
44 changes: 44 additions & 0 deletions xls/dslx/fmt/ast_fmt_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,50 @@ TEST_F(FunctionFmtTest, ConstAssert) {
EXPECT_EQ(got, want);
}

TEST_F(FunctionFmtTest, ConditionalInTernaryStyle) {
const std::string_view original =
"fn f(x:bool,y:u32,z:u32)->u32{if x{y}else{z}}";
XLS_ASSERT_OK_AND_ASSIGN(std::string got, DoFmt(original));
const std::string_view want =
R"(fn f(x: bool, y: u32, z: u32) -> u32 { if x { y } else { z } })";
EXPECT_EQ(got, want);
}

TEST_F(FunctionFmtTest, ConditionalMultiStatementCausesHardBreaks) {
const std::string_view original =
"fn f(x:bool,y:u32,z:u32)->u32{if x{y;z}else{z;y}}";
XLS_ASSERT_OK_AND_ASSIGN(std::string got, DoFmt(original));
const std::string_view want =
R"(fn f(x: bool, y: u32, z: u32) -> u32 {
if x {
y;
z
} else {
z;
y
}
})";
EXPECT_EQ(got, want);
}

TEST_F(FunctionFmtTest, ConditionalWithElseIf) {
const std::string_view original =
"fn f(a:bool[2],x:u32[3])->u32{if a[0]{x[0]}else if "
"a[1]{x[1]}else{x[2]}}";
XLS_ASSERT_OK_AND_ASSIGN(std::string got, DoFmt(original));
const std::string_view want =
R"(fn f(a: bool[2], x: u32[3]) -> u32 {
if a[0] {
x[0]
} else if a[1] {
x[1]
} else {
x[2]
}
})";
EXPECT_EQ(got, want);
}

// -- ModuleFmtTest cases, formatting entire modules

TEST(ModuleFmtTest, TwoSimpleFunctions) {
Expand Down
1 change: 1 addition & 0 deletions xls/dslx/frontend/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ cc_test(
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/types:span",
"//xls/common:casts",
"//xls/common:xls_gunit",
"//xls/common:xls_gunit_main",
"//xls/common/status:matchers",
Expand Down
13 changes: 13 additions & 0 deletions xls/dslx/frontend/ast.cc
Original file line number Diff line number Diff line change
Expand Up @@ -550,6 +550,19 @@ std::string Conditional::ToStringInternal() const {
ToAstNode(alternate_)->ToString());
}

bool Conditional::HasMultiStatementBlocks() const {
if (consequent_->size() > 1) {
return true;
}
return absl::visit(Visitor{
[](const Block* block) { return block->size() > 1; },
[](const Conditional* elseif) {
return elseif->HasMultiStatementBlocks();
},
},
alternate_);
}

// -- class Attr

Attr::~Attr() = default;
Expand Down
10 changes: 10 additions & 0 deletions xls/dslx/frontend/ast.h
Original file line number Diff line number Diff line change
Expand Up @@ -795,6 +795,8 @@ class Block : public Expr {
trailing_semi_ = false;
}

int64_t size() const { return statements_.size(); }

private:
Precedence GetPrecedenceInternal() const final {
return Precedence::kStrongest;
Expand Down Expand Up @@ -1442,6 +1444,14 @@ class Conditional : public Expr {
Block* consequent() const { return consequent_; }
std::variant<Block*, Conditional*> alternate() const { return alternate_; }

bool HasElseIf() const {
return std::holds_alternative<Conditional*>(alternate());
}

// Returns whether the blocks inside of this (potentially laddered)
// conditional have multiple statements.
bool HasMultiStatementBlocks() const;

private:
Precedence GetPrecedenceInternal() const final {
return Precedence::kStrongest;
Expand Down
12 changes: 10 additions & 2 deletions xls/dslx/frontend/parser_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/types/span.h"
#include "xls/common/casts.h"
#include "xls/common/status/matchers.h"
#include "xls/dslx/command_line_utils.h"
#include "xls/dslx/error_test_utils.h"
Expand Down Expand Up @@ -1396,7 +1397,10 @@ TEST_F(ParserTest, ForFreevars) {
TEST_F(ParserTest, EmptyTernary) { RoundTripExpr("if true {} else {}"); }

TEST_F(ParserTest, TernaryConditional) {
RoundTripExpr("if true { u32:42 } else { u32:24 }", {});
Expr* e = RoundTripExpr("if true { u32:42 } else { u32:24 }", {});

EXPECT_FALSE(down_cast<Conditional*>(e)->HasElseIf());
EXPECT_FALSE(down_cast<Conditional*>(e)->HasMultiStatementBlocks());

RoundTripExpr(R"(if really_long_identifier_so_that_this_is_too_many_chars {
u32:42
Expand All @@ -1407,7 +1411,11 @@ TEST_F(ParserTest, TernaryConditional) {
}

TEST_F(ParserTest, LadderedConditional) {
RoundTripExpr("if true { u32:42 } else if false { u32:33 } else { u32:24 }");
Expr* e = RoundTripExpr(
"if true { u32:42 } else if false { u32:33 } else { u32:24 }");

EXPECT_TRUE(down_cast<Conditional*>(e)->HasElseIf());
EXPECT_FALSE(down_cast<Conditional*>(e)->HasMultiStatementBlocks());

RoundTripExpr(
R"(if really_long_identifier_so_that_this_is_too_many_chars {
Expand Down

0 comments on commit 8e4d778

Please sign in to comment.