Skip to content

Commit

Permalink
Make IR SSA.
Browse files Browse the repository at this point in the history
There's no Phi node in our IR. Instead, basic blocks take parameters.
I borrowed the idea from Swift Intermediate Language.
  • Loading branch information
rui314 committed Sep 1, 2018
1 parent ac74d2e commit b2d2459
Show file tree
Hide file tree
Showing 5 changed files with 119 additions and 87 deletions.
14 changes: 9 additions & 5 deletions 9cc.h
Original file line number Diff line number Diff line change
Expand Up @@ -338,16 +338,17 @@ enum {
IR_NOP,
};

typedef struct BB {
int label;
Vector *ir;
} BB;

typedef struct {
int vn; // virtual register number
int rn; // real register number
} Reg;

typedef struct BB {
int label;
Vector *ir;
Reg *param;
} BB;

typedef struct {
int op;

Expand All @@ -372,6 +373,9 @@ typedef struct {

// For liveness tracking
Vector *kill;

// For SSA
Reg *bbarg;
} IR;

void gen_ir(Program *prog);
Expand Down
168 changes: 91 additions & 77 deletions gen_ir.c
Original file line number Diff line number Diff line change
Expand Up @@ -65,15 +65,18 @@ static void kill(Reg *r) {
vec_push(ir->kill, r);
}

static void jmp(BB *bb) {
static IR *jmp(BB *bb) {
IR *ir = new_ir(IR_JMP);
ir->bb1 = bb;
return ir;
}

static void imm(Reg *r, int imm) {
static Reg *imm(int imm) {
Reg *r = new_reg();
IR *ir = new_ir(IR_IMM);
ir->r0 = r;
ir->imm = imm;
return r;
}

static Reg *gen_expr(Node *node);
Expand All @@ -83,11 +86,6 @@ static void load(Node *node, Reg *dst, Reg *src) {
ir->size = node->ty->size;
}

static void store(Node *node, Reg *dst, Reg *src) {
IR *ir = emit(IR_STORE, dst, NULL, src);
ir->size = node->ty->size;
}

// In C, all expressions that can be written on the left-hand side of
// the '=' operator must have an address in memory. In other words, if
// you can apply the '&' operator to take an address of some
Expand All @@ -110,12 +108,13 @@ static Reg *gen_lval(Node *node) {
return gen_expr(node->expr);

if (node->op == ND_DOT) {
Reg *r = gen_lval(node->expr);
Reg *r2 = new_reg();
imm(r2, node->ty->offset);
emit(IR_ADD, r, r, r2);
Reg *r1 = new_reg();
Reg *r2 = gen_lval(node->expr);
Reg *r3 = imm(node->ty->offset);
emit(IR_ADD, r1, r2, r3);
kill(r2);
return r;
kill(r3);
return r1;
}

assert(node->op == ND_VARREF);
Expand Down Expand Up @@ -147,67 +146,77 @@ static void gen_stmt(Node *node);

static Reg *gen_expr(Node *node) {
switch (node->op) {
case ND_NUM: {
Reg *r = new_reg();
imm(r, node->val);
return r;
}
case ND_NUM:
return imm(node->val);
case ND_EQ:
return gen_binop(IR_EQ, node);
case ND_NE:
return gen_binop(IR_NE, node);
case ND_LOGAND: {
BB *bb1 = new_bb();
BB *bb2 = new_bb();
BB *bb = new_bb();
BB *set0 = new_bb();
BB *set1 = new_bb();
BB *last = new_bb();

Reg *r = gen_expr(node->lhs);
br(r, bb1, last);
Reg *r1 = gen_expr(node->lhs);
br(r1, bb, set0);
kill(r1);

out = bb1;
out = bb;
Reg *r2 = gen_expr(node->rhs);
emit(IR_MOV, r, r, r2);
br(r2, set1, set0);
kill(r2);
br(r, bb2, last);

out = bb2;
imm(r, 1);
jmp(last);
out = set0;
Reg *r3 = imm(0);
jmp(last)->bbarg = r3;
kill(r3);

out = set1;
Reg *r4 = imm(1);
jmp(last)->bbarg = r4;
kill(r4);

out = last;
return r;
out->param = new_reg();
return out->param;
}
case ND_LOGOR: {
BB *bb = new_bb();
BB *set0 = new_bb();
BB *set1 = new_bb();
BB *last = new_bb();

Reg *r = gen_expr(node->lhs);
br(r, set1, bb);

out = set0;
imm(r, 0);
jmp(last);

out = set1;
imm(r, 1);
jmp(last);
Reg *r1 = gen_expr(node->lhs);
br(r1, set1, bb);
kill(r1);

out = bb;
Reg *r2 = gen_expr(node->rhs);
emit(IR_MOV, r, r, r2);
br(r2, set1, set0);
kill(r2);
br(r, set1, set0);

out = set0;
Reg *r3 = imm(0);
jmp(last)->bbarg = r3;
kill(r3);

out = set1;
Reg *r4 = imm(1);
jmp(last)->bbarg = r4;
kill(r4);

out = last;
return r;
out->param = new_reg();
return out->param;
}
case ND_VARREF:
case ND_DOT: {
Reg *r = gen_lval(node);
load(node, r, r);
return r;
Reg *r1 = new_reg();
Reg *r2 = gen_lval(node);
load(node, r1, r2);
kill(r2);
return r1;
}
case ND_CALL: {
Reg *args[6];
Expand All @@ -233,25 +242,28 @@ static Reg *gen_expr(Node *node) {
return r;
}
case ND_CAST: {
Reg *r = gen_expr(node->expr);
Reg *r1 = gen_expr(node->expr);
if (node->ty->ty != BOOL)
return r;
Reg *r2 = new_reg();
imm(r2, 0);
emit(IR_NE, r, r, r2);
return r1;
Reg *r2 = imm(0);
Reg *r3 = new_reg();
emit(IR_NE, r3, r1, r2);
kill(r1);
kill(r2);
return r;
return r3;
}
case ND_STMT_EXPR:
for (int i = 0; i < node->stmts->len; i++)
gen_stmt(node->stmts->data[i]);
return gen_expr(node->expr);
case '=': {
Reg *rhs = gen_expr(node->rhs);
Reg *lhs = gen_lval(node->lhs);
store(node, lhs, rhs);
kill(lhs);
return rhs;
Reg *r1 = gen_expr(node->rhs);
Reg *r2 = gen_lval(node->lhs);

IR *ir = emit(IR_STORE, NULL, r2, r1);
ir->size = node->ty->size;
kill(r2);
return r1;
}
case '+':
return gen_binop(IR_ADD, node);
Expand All @@ -278,12 +290,13 @@ static Reg *gen_expr(Node *node) {
case ND_SHR:
return gen_binop(IR_SHR, node);
case '~': {
Reg *r = gen_expr(node->expr);
Reg *r2 = new_reg();
imm(r2, -1);
emit(IR_XOR, r, r, r2);
Reg *r1 = new_reg();
Reg *r2 = gen_expr(node->expr);
Reg *r3 = imm(-1);
emit(IR_XOR, r1, r2, r3);
kill(r2);
return r;
kill(r3);
return r1;
}
case ',':
kill(gen_expr(node->lhs));
Expand All @@ -293,31 +306,32 @@ static Reg *gen_expr(Node *node) {
BB *els = new_bb();
BB *last = new_bb();

Reg *r = gen_expr(node->cond);
br(r, then, els);
Reg *r1 = gen_expr(node->cond);
br(r1, then, els);
kill(r1);

out = then;
Reg *r2 = gen_expr(node->then);
emit(IR_MOV, r, r, r2);
jmp(last)->bbarg = r2;
kill(r2);
jmp(last);

out = els;
Reg *r3 = gen_expr(node->els);
emit(IR_MOV, r, r, r3);
kill(r2);
jmp(last);
jmp(last)->bbarg = r3;
kill(r3);

out = last;
return r;
out->param = new_reg();
return out->param;
}
case '!': {
Reg *lhs = gen_expr(node->expr);
Reg *rhs = new_reg();
imm(rhs, 0);
emit(IR_EQ, lhs, lhs, rhs);
kill(rhs);
return lhs;
Reg *r1 = new_reg();
Reg *r2 = gen_expr(node->expr);
Reg *r3 = imm(0);
emit(IR_EQ, r1, r2, r3);
kill(r2);
kill(r3);
return r1;
}
default:
assert(0 && "unknown AST type");
Expand Down Expand Up @@ -410,11 +424,11 @@ static void gen_stmt(Node *node) {

BB *next = new_bb();
Reg *r2 = new_reg();

imm(r2, case_->val);
emit(IR_EQ, r2, r2, r);
Reg *r3 = imm(case_->val);
emit(IR_EQ, r2, r3, r);
br(r2, case_->bb, next);
kill(r2);
kill(r3);
out = next;
}
jmp(node->break_);
Expand Down
7 changes: 5 additions & 2 deletions gen_x86.c
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ static char *argreg(int r, int size) {

static void emit_ir(IR *ir, char *ret) {
int r0 = ir->r0 ? ir->r0->rn : 0;
int r1 = ir->r1 ? ir->r1->rn : 0;
int r2 = ir->r2 ? ir->r2->rn : 0;

switch (ir->op) {
Expand Down Expand Up @@ -121,6 +122,8 @@ static void emit_ir(IR *ir, char *ret) {
emit("shr %s, cl", regs[r0]);
break;
case IR_JMP:
if (ir->bbarg)
emit("mov %s, %s", regs[ir->bb1->param->rn], regs[ir->bbarg->rn]);
emit("jmp .L%d", ir->bb1->label);
break;
case IR_BR:
Expand All @@ -134,7 +137,7 @@ static void emit_ir(IR *ir, char *ret) {
emit("movzb %s, %s", regs[r0], regs8[r0]);
break;
case IR_STORE:
emit("mov [%s], %s", regs[r0], reg(r2, ir->size));
emit("mov [%s], %s", regs[r1], reg(r2, ir->size));
break;
case IR_STORE_ARG:
emit("mov [rbp%d], %s", ir->imm, argreg(ir->imm2, ir->size));
Expand Down Expand Up @@ -185,7 +188,7 @@ void emit_code(Function *fn) {

for (int i = 0; i < fn->bbs->len; i++) {
BB *bb = fn->bbs->data[i];
p(".L%d:\n", bb->label);
p(".L%d:", bb->label);
for (int i = 0; i < bb->ir->len; i++) {
IR *ir = bb->ir->data[i];
emit_ir(ir, ret);
Expand Down
10 changes: 8 additions & 2 deletions irdump.c
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ static char *tostr(IR *ir) {
case IR_IMM:
return format("r%d = %d", r0, ir->imm);
case IR_JMP:
if (ir->bbarg)
return format("JMP .L%d (r%d)", ir->bb1->label, regno(ir->bbarg));
return format("JMP .L%d", ir->bb1->label);
case IR_LABEL_ADDR:
return format("r%d = .L%d", r0, ir->label);
Expand Down Expand Up @@ -69,7 +71,7 @@ static char *tostr(IR *ir) {
case IR_RETURN:
return format("RET r%d", r0);
case IR_STORE:
return format("STORE%d r%d, r%d", ir->size, r0, r2);
return format("STORE%d r%d, r%d", ir->size, r1, r2);
case IR_STORE_ARG:
return format("STORE_ARG%d %d, %d", ir->size, ir->imm, ir->imm2);
case IR_SUB:
Expand All @@ -90,7 +92,11 @@ void dump_ir(Vector *irv) {

for (int i = 0; i < fn->bbs->len; i++) {
BB *bb = fn->bbs->data[i];
fprintf(stderr, ".L%d:\n", bb->label);

if (bb->param)
fprintf(stderr, ".L%d(r%d):\n", bb->label, regno(bb->param));
else
fprintf(stderr, ".L%d:\n", bb->label);

for (int i = 0; i < bb->ir->len; i++) {
IR *ir = bb->ir->data[i];
Expand Down
Loading

0 comments on commit b2d2459

Please sign in to comment.