Skip to content
Open
2 changes: 0 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,6 @@ To focus on the superoptimizer and not making a comprehensive, realistic assembl

There are many possible improvements:

- **Start state.** Right now it assumes the start state is always the same, which means there is no concept of program input.
- **Program equivalence.** A set of inputs and outputs should be specified such that two programs can actually be tested for equivalence.
- **Pruning.** Many nonsensical programs are generated, which significantly slows it down.
- **More instructions.** There need to be more instructions, especially a conditional instruction, to give the superoptimizer more opportunities to make improvements.

Expand Down
56 changes: 28 additions & 28 deletions assembler.py
Original file line number Diff line number Diff line change
@@ -1,34 +1,34 @@
import re
from cpu import CPU
from instruction_set import *


INSTRUCTION_REGEX = re.compile(r'(\w+)\s+([-\d]+(?:\s*,\s*[-\d]+)*)')


# Turns a string into a program.
def parse(assembly):
"""
Turns a string into a program
"""
lines = assembly.split('\n')
program = []
cpu = CPU(1)
instructions = []
mem_size = 1
for line in lines:
match = re.match(r'(\w+)\s+([-\d]+)(?:,\s*([-\d]+)(?:,\s*([-\d]+))?)?', line)
line = line.strip()
if line == '':
continue
match = INSTRUCTION_REGEX.fullmatch(line)
if match:
op_str, *args_str = match.groups()
op = cpu.ops[op_str]
args = [int(arg) for arg in args_str if arg is not None]
program.append((op, *args))
return program

# Turns a program into a string.
def output(program):
if len(program) == 0: return "\n"
cpu = CPU(1)
assembly = ""
for instruction in program:
op = instruction[0]
args = instruction[1:]
if op.__name__ == cpu.load.__name__:
assembly += f"LOAD {args[0]}\n"
elif op.__name__ == cpu.swap.__name__:
assembly += f"SWAP {args[0]}, {args[1]}\n"
elif op.__name__ == cpu.xor.__name__:
assembly += f"XOR {args[0]}, {args[1]}\n"
elif op.__name__ == cpu.inc.__name__:
assembly += f"INC {args[0]}\n"
return assembly
op, args_str = match.groups()
args = tuple(int(arg) for arg in args_str.split(","))
operand_types = OPS[op]
if len(args) != len(operand_types):
raise ValueError(f'Wrong number of operands: {line}')
for arg, arg_type in zip(args, operand_types):
if arg_type == 'mem':
if arg < 0:
raise ValueError(f'Negative memory address: {line}')
mem_size = max(arg + 1, mem_size)
instructions.append(Instruction(op, args))
else:
raise ValueError(f'Invalid syntax: {line}')
return Program(tuple(instructions), mem_size)
29 changes: 29 additions & 0 deletions brute_force_equivialence_checker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
from cpu import CPU


class BruteForceEquivalenceChecker:
def __init__(self, program1, bit_width, input_size):
self.program1 = program1
self.bit_width = bit_width
self.max_val = 2 ** bit_width
self.input_size = input_size

def generate_inputs(self, input_size):
"""
Generates all possible tuples of the given size with values ranging from 0 (inclusive)
to `max_val` (exclusive).
"""
if input_size == 0:
yield ()
else:
for x in range(self.max_val):
for rest in self.generate_inputs(input_size - 1):
yield x, *rest

def is_equivalent_to(self, program2):
mem_size = max(self.program1.mem_size, program2.mem_size)
cpu = CPU(mem_size, self.bit_width)
for input in self.generate_inputs(self.input_size):
if cpu.execute(self.program1, input) != cpu.execute(program2, input):
return False
return True
56 changes: 30 additions & 26 deletions cpu.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,34 @@
import assembler


def run(assembly, bit_width, input=()):
"""
Helper function that runs a piece of assembly code.
"""
program = assembler.parse(assembly)
cpu = CPU(program.mem_size, bit_width)
return cpu.execute(program, input)


class CPU:
def __init__(self, max_mem_cells):
def __init__(self, max_mem_cells, bit_width):
self.max_mem_cells = max_mem_cells
self.state = [0] * max_mem_cells
self.ops = {'LOAD': self.load, 'SWAP': self.swap, 'XOR': self.xor, 'INC': self.inc}
self.limit = 2 ** bit_width

def execute(self, program):
state = self.state.copy()
for instruction in program:
op = instruction[0]
args = list(instruction[1:])
args.insert(0, state)
state = op(*args)
return state

def load(self, state, val):
state[0] = val
return state

def swap(self, state, mem1, mem2):
state[mem1], state[mem2] = state[mem2], state[mem1]
return state

def xor(self, state, mem1, mem2):
state[mem1] = state[mem1] ^ state[mem2]
def execute(self, program, input=()):
state = [0] * self.max_mem_cells
state[0: len(input)] = input
for instruction in program.instructions:
match instruction.opcode:
case 'LOAD':
state[0] = instruction.args[0] % self.limit
case 'SWAP':
mem1, mem2 = instruction.args
state[mem1], state[mem2] = state[mem2], state[mem1]
case 'XOR':
mem1, mem2 = instruction.args
state[mem1] ^= state[mem2]
case 'INC':
mem = instruction.args[0]
state[mem] = (state[mem] + 1) % self.limit
return state

def inc(self, state, mem):
state[mem] += 1
return state
35 changes: 35 additions & 0 deletions instruction_set.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
from dataclasses import dataclass


@dataclass
class Instruction:
opcode: str
args: tuple[int, ...]

def __str__(self):
args = ", ".join(str(arg) for arg in self.args)
return f"{self.opcode} {args}"


@dataclass
class Program:
instructions: tuple[Instruction, ...]
"""
The instructions that make up this program
"""

mem_size: int
"""
The amount of memory needed to run this program
"""

def __str__(self):
return "\n".join(str(instr) for instr in self.instructions) + "\n"


OPS = {
"LOAD": ("const",),
"SWAP": ("mem", "mem"),
"XOR": ("mem", "mem"),
"INC": ("mem",)
}
46 changes: 31 additions & 15 deletions main.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,40 @@
from superoptimizer import *
from superoptimizer import optimize
from cpu import run


def print_optimal_from_code(assembly, max_length, bit_width, debug=False):
print(f"***Source***{assembly}")
state = run(assembly, bit_width)
print("***State***")
print(state)
print()
print("***Optimal***")
print(optimize(assembly, max_length, bit_width, debug=debug))
print("=" * 20)
print()


def main():
# Test 1
assembly = """
LOAD 3
SWAP 0, 1
LOAD 3
SWAP 0, 2
LOAD 3
SWAP 0, 3
LOAD 3
LOAD 3
SWAP 0, 1
LOAD 3
SWAP 0, 2
LOAD 3
SWAP 0, 3
LOAD 3
"""
optimal_from_code(assembly, 4, 4, 5)
print_optimal_from_code(assembly, 4, 2)

# Test 2
state = [0, 2, 1]
optimal_from_state(state, 3, 5)
assembly = """
LOAD 2
SWAP 0, 1
LOAD 1
SWAP 0, 2
"""
print_optimal_from_code(assembly, 3, 2)

## Test 3 - Careful, I don't think this will finish for days.
# state = [2, 4, 6, 8, 10, 12]
# optimal_from_state(state, 10, 15, True)

main()
main()
2 changes: 2 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
pytest
z3-solver
16 changes: 16 additions & 0 deletions smt_based_equivalence_checker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import z3
from smt_program_simulator import simulate


class SmtBasedEquivalenceChecker:
def __init__(self, program1, bit_width, input_size):
self.solver = z3.Solver()
self.bit_width = bit_width
self.input_size = input_size
self.mem_size = program1.mem_size
self.state1 = simulate(program1, self.mem_size, bit_width, input_size)

def is_equivalent_to(self, program2):
state2 = simulate(program2, self.mem_size, self.bit_width, self.input_size)
programs_are_different = z3.Or(*(value1 != value2 for value1, value2 in zip(self.state1, state2)))
return self.solver.check(programs_are_different) == z3.unsat
32 changes: 32 additions & 0 deletions smt_program_simulator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import z3


def simulate(program, mem_size, bit_width, input_size):
"""
Simulate the behavior of the program using an SMT solver.

The result will be a list containing, for each memory cell, an SMT value representing the
value that will reside in that memory location after running the program.
"""

def mem_cell(i):
if i < input_size:
return z3.BitVec(f'input{i}', bit_width)
else:
return z3.BitVecVal(0, bit_width)

state = [mem_cell(i) for i in range(mem_size)]
for instruction in program.instructions:
match instruction.opcode:
case 'LOAD':
state[0] = z3.BitVecVal(instruction.args[0], bit_width)
case 'SWAP':
mem1, mem2 = instruction.args
state[mem1], state[mem2] = state[mem2], state[mem1]
case 'XOR':
mem1, mem2 = instruction.args
state[mem1] ^= state[mem2]
case 'INC':
mem = instruction.args[0]
state[mem] += 1
return state
Loading