-
Notifications
You must be signed in to change notification settings - Fork 0
/
commands.py
92 lines (68 loc) · 2.47 KB
/
commands.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
from abc import ABC, abstractmethod
from typing import Any, Dict, List, FrozenSet, Union
from dataclasses import dataclass
from cvc5 import Kind
from expressions import Expr, IfExpr, add, ConstantExpr
from functools import reduce
class Command(ABC):
    """Abstract base class for a command (statement) of the analysed language.

    Every concrete command must be able to report the variables it mentions
    and to build an expression describing the cost of executing it.
    """

    @abstractmethod
    def variables(self):
        """Return the set of variable names referenced by this command."""

    @abstractmethod
    def eval_complexity(self, functions: List["BaseFunc"]):
        """Return an expression for the cost of executing this command.

        Args:
            functions: Known function definitions; call commands look up
                their callee in this list by ``name``.
        """
@dataclass
class IfCommand(Command):
    """Conditional command: ``if(condition) then true else false``."""

    condition: Expr
    true: Command
    false: Command

    def variables(self):
        """Return the variables used by the condition and by both branches."""
        in_condition = self.condition.variables()
        in_then = self.true.variables()
        in_else = self.false.variables()
        return in_condition | in_then | in_else

    def eval_complexity(self, functions: List["BaseFunc"]):
        """Return ``IfExpr(condition, then_cost, else_cost)``.

        No extra cost is added here for evaluating the condition itself.
        """
        then_cost = self.true.eval_complexity(functions)
        else_cost = self.false.eval_complexity(functions)
        return IfExpr(self.condition, then_cost, else_cost)

    def __repr__(self) -> str:
        return "if({}) then {} else {}".format(self.condition, self.true, self.false)
@dataclass
class FunctionCallCommand(Command):
    """Call of the function named ``func_name`` with argument expressions ``args``."""

    func_name: str
    args: List[Expr]

    def variables(self):
        """Return the union of the variables referenced by every argument."""
        return set().union(*(arg.variables() for arg in self.args))

    def eval_complexity(self, functions: List["BaseFunc"]):
        """Return the callee's cost with arguments substituted, plus 1 for the call.

        The callee is the unique entry of ``functions`` whose ``name`` equals
        ``func_name``.  Its ``input_names`` are bound positionally to
        ``self.args`` and passed to ``func.T.substitute_evaluate``.

        Raises:
            Exception: if no function or more than one function has the
                requested name, or if the number of arguments differs from
                the number of the callee's ``input_names``.
        """
        funcs = [func for func in functions if func.name == self.func_name]
        if not funcs:
            raise Exception(f"no function named {self.func_name} to call")
        if len(funcs) > 1:
            raise Exception(f"multiple functions named {self.func_name} it is ambiguous")
        func = funcs[0]
        input_names = list(func.input_names)
        # zip() silently truncates to the shorter sequence, so an arity
        # mismatch would drop arguments or leave parameters unsubstituted.
        if len(input_names) != len(self.args):
            raise Exception(
                f"function {self.func_name} expects {len(input_names)} "
                f"argument(s) but got {len(self.args)}"
            )
        bindings = dict(zip(input_names, self.args))
        return add(func.T.substitute_evaluate(bindings), ConstantExpr(1))

    def __repr__(self):
        return f"{self.func_name}({','.join(f'{x}' for x in self.args)})"
@dataclass
class BlockCommand(Command):
    """Two commands composed in sequence, rendered as ``first; second``."""

    first: Command
    second: Command

    def variables(self):
        """Return the variables referenced by either sub-command."""
        used_by_first = self.first.variables()
        used_by_second = self.second.variables()
        return used_by_first | used_by_second

    def eval_complexity(self, functions: List["BaseFunc"]):
        """Return ``add(first_cost, second_cost)``, the sum of both costs."""
        first_cost = self.first.eval_complexity(functions)
        second_cost = self.second.eval_complexity(functions)
        return add(first_cost, second_cost)

    def __repr__(self) -> str:
        return "{}; {}".format(self.first, self.second)
@dataclass
class PassCommand(Command):
    """No-op command: it references no variables and costs ``ConstantExpr(0)``."""

    def variables(self):
        """Return a fresh empty set, since ``pass`` mentions no variables."""
        no_variables = set()
        return no_variables

    def eval_complexity(self, functions: List["BaseFunc"]):
        """Return a constant zero cost; ``functions`` is ignored."""
        zero_cost = ConstantExpr(0)
        return zero_cost

    def __repr__(self) -> str:
        return "pass"
def make_block(commands):
    """Combine commands into a single left-nested ``BlockCommand``.

    ``[a, b, c]`` becomes ``BlockCommand(BlockCommand(a, b), c)``.  A single
    command is returned unchanged, and an empty input yields ``PassCommand()``.

    Args:
        commands: Any iterable of commands (list, tuple, generator, ...).
            Unlike a ``len()``-based check, this also accepts one-shot
            iterators.

    Returns:
        The combined command.
    """
    remaining = iter(commands)
    try:
        block = next(remaining)
    except StopIteration:
        return PassCommand()
    # Left fold: each further command is appended after what was built so far.
    for command in remaining:
        block = BlockCommand(block, command)
    return block