Skip to content

Commit

Permalink
feat: type-aware tree generation
Browse files Browse the repository at this point in the history
  • Loading branch information
lweitzendorf committed Jul 25, 2022
1 parent 60317cd commit 76db2e6
Show file tree
Hide file tree
Showing 2 changed files with 109 additions and 125 deletions.
26 changes: 11 additions & 15 deletions src/gen/gen_configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
# SOFTWARE.


import random
import importlib
from typing import Dict, List

Expand Down Expand Up @@ -153,15 +152,16 @@ def get_operator_types(theory: str) -> List[str]:


def get_arities(theory: str) -> List[int]:
arities = []
arities = set()
for op in main_operators[theory].keys():
arities.append(len(get_operator_parameters(theory, op)))
return arities
arities.add(len(get_operator_parameters(theory, op)))
return list(arities)


def get_operator_parameters(theory: str, operator: str) -> List[str]:
params = theory_declarations[theory][operator]
return params
if operator in leaf_operators[theory].values():
return []
return theory_declarations[theory][operator]


def get_constant(theory: str) -> str:
Expand All @@ -180,19 +180,15 @@ def get_root(theory: str) -> str:
return root_operators[theory]


def get_eligible_operator(theory: str, arity: int) -> str:
if theory is None:
theory = random.choice(get_theories())

theory = main_operators[theory]
def get_eligible_operators(theory: str, min_arity: int, max_arity: int) -> List[str]:
operator_choices = []

for operator in theory.keys():
n = len([p for p in theory[operator] if "Operator" in p])
if n == arity:
for operator in main_operators[theory].keys():
n = len(main_operators[theory][operator])
if min_arity <= n <= max_arity:
operator_choices.append(operator)

return random.choice(operator_choices)
return operator_choices


def get_theory_name(theory: str) -> str:
Expand Down
208 changes: 98 additions & 110 deletions src/gen/tree_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
# SOFTWARE.


import math
import random
import re

Expand All @@ -29,7 +30,7 @@
from src.gen.gen_configuration import (
get_constant,
get_variable,
get_eligible_operator,
get_eligible_operators,
get_root,
get_operator_class,
get_arities,
Expand All @@ -40,123 +41,110 @@
constant_name_pattern = re.compile(r"^c\d+$")


# Ordered Tree Encoding:
# an n-tuple for a tree of n nodes:
# t = (t1, t2, ..., ti,..., tn)
# ti is the i-th node of the tree t.
# A parent node implies ti > 0, while a leaf node implies ti = 0.
# The nodes are visited in preorder (depth-first left to right) to determine the order of the tree.
# The last node of the tree is thus tn, so tn = 0
# and ti is in [0, n - 1].
# Note: we use a list for simplicity/efficiency.
def _generate_arity_tree(size: int, arities: List[int], min_leaves: int):
"""
Generates a list of integers representing a tree containing n
nodes. The first node is already set to root.
"""
tree = []
rem_leaves = min_leaves
min_op_arity = min(arities)

while len(tree) < size:
max_arity = size - sum(tree) - 1

if 0 < max_arity < min_op_arity:
discrepancy = min_op_arity - max_arity
print(
f"Cannot match tree size {size} exactly, increasing to {size + discrepancy}...")
size += discrepancy
max_arity += discrepancy

rem_operators = size - len(tree) - rem_leaves
max_sub_leaves = (rem_operators - 1) * (max(arities) - 1) + 1
min_arity = rem_leaves - (sum(tree) - len(tree)) - max_sub_leaves + 1

branching_choices = [
n for n in arities if (min_arity <= n <= max_arity)]

if not (sum(tree) == len(tree) and max_arity > 0):
branching_choices.extend([0, 0]) # two different leaf operators

arity = random.choice(branching_choices)
tree.append(arity)

if arity == 0:
rem_leaves -= 1

return tree


def _generate_operator_tree(theory, arity_tree, num_variables) -> Operator:
num_leaves = len([n for n in arity_tree if n == 0])
num_constants = num_leaves - num_variables

if num_leaves < num_variables:
raise ValueError(
"Not enough leaves to accommodate requested number of variables.")

leaves = [False] * num_constants + [True] * num_variables
random.shuffle(leaves)

def recursive_generation(idx, operator_type):
n = arity_tree[idx]
idx += 1
params = []

if n == 0:
nonlocal num_leaves
num_leaves -= 1
is_variable = leaves[num_leaves]

if is_variable:
op_name = get_variable(operator_type)
else:
op_name = get_constant(operator_type)
else:
op_name = get_eligible_operator(operator_type, n)

for input_type in get_operator_parameters(operator_type, op_name):
param, idx = recursive_generation(idx, input_type)
params.append(param)

return get_operator_class(operator_type, op_name)(*params), idx

operator_tree, _ = recursive_generation(0, theory)
root_name = get_root(theory)

output_var = get_operator_class(theory, get_variable(theory))()
root = get_operator_class(theory, root_name)(output_var, operator_tree)

return root


def generate_tree(theory: str, size: int, in_variables: Union[int, List[str]] = 2, out_variable: str = 'z') -> Operator:
if isinstance(in_variables, int):
in_variables = [f'x{i+1}' for i in range(in_variables)]
else:
# Check that names do not clash with constant generated names.
for v in in_variables:
if constant_name_pattern.match(v) is not None:
ValueError(
"The list of variables should not contain a name matching the constant name pattern: c[0-9]")

num_variables = len(in_variables)
arities = get_arities(theory)
if constant_name_pattern.match(v):
ValueError("The list of variables should not "
"contain a name matching the constant name pattern: c[0-9]")

def get_leaf_bounds(tree_type, tree_size):
available_arities = get_arities(tree_type)
min_op_arity = min(available_arities)
max_op_arity = max(available_arities)

def get_bound(arity):
return (tree_size * (arity - 1) + 1) / arity

low = math.ceil(get_bound(min_op_arity))
high = math.floor(get_bound(max_op_arity))
return low, high

gen_size = size - 2 # subtract root and output var from size
gen_num_var = len(in_variables)
min_leaf, max_leaf = get_leaf_bounds(theory, gen_size)

if min_leaf > max_leaf:
gen_size += 1 # maybe print a message here
min_leaf, max_leaf = get_leaf_bounds(theory, gen_size)

if max_leaf < gen_num_var:
raise ValueError(f"Tree of size {size} cannot accommodate {gen_num_var} variables")

gen_num_leaf = random.randint(max(gen_num_var, min_leaf), max_leaf)
gen_num_internal = gen_size - gen_num_leaf
gen_num_const = gen_num_leaf - gen_num_var

leaf_idx = 0
leaf_type = [False] * gen_num_const + [True] * gen_num_var
random.shuffle(leaf_type)

def is_var():
nonlocal leaf_idx
leaf_idx += 1
return leaf_type[leaf_idx-1]

def generate_subtree(op_type, num_internal, num_leaf):
available_arities = get_arities(op_type)

if num_internal >= 1:
max_sub_leaf = (num_internal - 1) * (max(available_arities) - 1) + 1
min_arity = num_leaf - max_sub_leaf + 1
min_sub_leaf = (num_internal - 1) * (min(available_arities) - 1) + 1
max_arity = num_leaf - min_sub_leaf + 1
op_choices = get_eligible_operators(op_type, min_arity, max_arity)
op_name = random.choice(op_choices)
else:
assert num_internal == 0
assert num_leaf == 1
op_func = get_variable if is_var() else get_constant
op_name = op_func(theory)

parameters = get_operator_parameters(theory, op_name)
op_arity = len(parameters)
assert num_leaf >= op_arity

children = []
rem_internal = num_internal - 1 # not sure if max is necessary here
rem_leaf = num_leaf

for idx, input_type in enumerate(parameters):
if idx < op_arity - 1:
# guarantee at least one leaf per subtree
rem_children = op_arity - (idx + 1)
child_leaf = random.randint(1, rem_leaf - rem_children)
rem_leaf -= child_leaf

def get_bound(num_leaves, arity):
return (num_leaves - 1) / (arity - 1) if arity > 1 else 1e6

min_op_arity = min(available_arities)
max_op_arity = max(available_arities)

min_internal_to_cover = math.ceil(get_bound(rem_leaf, max_op_arity))
child_internal_high = math.floor(get_bound(child_leaf, min_op_arity))
child_internal_high = min(child_internal_high, rem_internal - min_internal_to_cover)

child_internal_low = math.ceil(get_bound(child_leaf, max_op_arity))
child_internal_low = min(child_internal_low, child_internal_high)
child_internal = random.randint(child_internal_low, child_internal_high)
rem_internal -= child_internal
else:
child_internal = rem_internal
child_leaf = rem_leaf

if (size - num_variables) * (max(arities) - 1) + 1 < num_variables:
raise ValueError("Tree size too small to accommodate all variables")
child = generate_subtree(input_type, child_internal, child_leaf)
children.append(child)

tree = _generate_arity_tree(size, arities, num_variables)
tree = _generate_operator_tree(theory, tree, num_variables)
return get_operator_class(theory, op_name)(*children)

init_visitor = InitializationVisitor(in_variables, out_variable)
tree.accept(init_visitor)
return tree
operator_tree = generate_subtree(theory, gen_num_internal, gen_num_leaf)
output_var = get_operator_class(theory, get_variable(theory))()

tree = get_operator_class(theory, get_root(theory))(output_var, operator_tree)
tree.accept(InitializationVisitor(in_variables, out_variable))

def validate(tree: List[int]) -> bool:
"""
Returns true if a tree of n nodes respect the property:
sum from i to n of ti = n - 1
"""
return sum(tree) == len(tree) - 1 and tree[len(tree) - 1] == 0
return tree

0 comments on commit 76db2e6

Please sign in to comment.