Skip to content

Commit

Permalink
Add optional max_attempts argument to bipartition tree funcs (#389)
Browse files Browse the repository at this point in the history
* Add optional max_attempts argument to bipartition tree funcs

This patch adds a max_attempts argument to the bipartition_tree funcs in
`gerrychain.tree`. If max_attempts is None, it will default to the
previous behavior (i.e. it will hang if it gets stuck). If max_attempts
is set, it will throw a RuntimeError when `max_attempts` is reached.

This patch is based off code from Robi Huq (via Jeanne Clelland), but
is rewritten. Any mistakes are likely my own.

* Bump up default max_attempts to 10k, just in case

Co-authored-by: Robi Huq <robiohuq@gmail.com>
  • Loading branch information
InnovativeInventor and RobiHuq committed Mar 20, 2022
1 parent ae7e776 commit 005d77a
Showing 1 changed file with 34 additions and 10 deletions.
44 changes: 34 additions & 10 deletions gerrychain/tree.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import networkx as nx
from networkx.algorithms import tree

from functools import partial
from .random import random
from collections import deque, namedtuple

Expand Down Expand Up @@ -168,7 +169,8 @@ def bipartition_tree(
spanning_tree=None,
spanning_tree_fn=random_spanning_tree,
balance_edge_fn=find_balanced_edge_cuts_memoization,
choice=random.choice
choice=random.choice,
max_attempts=None
):
"""This function finds a balanced 2 partition of a graph by drawing a
spanning tree and finding an edge to cut that leaves at most an epsilon
Expand All @@ -193,22 +195,30 @@ def bipartition_tree(
:param spanning_tree_fn: The random spanning tree algorithm to use if a spanning
tree is not provided
:param choice: :func:`random.choice`. Can be substituted for testing.
:param max_atempts: The max number of attempts that should be made to bipartition.
"""
populations = {node: graph.nodes[node][pop_col] for node in graph.node_indices}

possible_cuts = []
if spanning_tree is None:
spanning_tree = spanning_tree_fn(graph)

restarts = 0
while len(possible_cuts) == 0:
attempts = 0
while max_attempts is None or attempts < max_attempts:
if restarts == node_repeats:
spanning_tree = spanning_tree_fn(graph)
restarts = 0
h = PopulatedGraph(spanning_tree, populations, pop_target, epsilon)
possible_cuts = balance_edge_fn(h, choice=choice)

if len(possible_cuts) != 0:
return choice(possible_cuts).subset

restarts += 1
attempts += 1

return choice(possible_cuts).subset
raise RuntimeError(f"Could not find a possible cut after {max_attempts} attempts.")


def _bipartition_tree_random_all(
Expand All @@ -222,6 +232,7 @@ def _bipartition_tree_random_all(
spanning_tree_fn=random_spanning_tree,
balance_edge_fn=find_balanced_edge_cuts_memoization,
choice=random.choice,
max_attempts=None
):
"""Randomly bipartitions a graph and returns all cuts."""
populations = {node: graph.nodes[node][pop_col] for node in graph.node_indices}
Expand All @@ -231,12 +242,19 @@ def _bipartition_tree_random_all(
spanning_tree = spanning_tree_fn(graph)

repeat = True
while repeat and len(possible_cuts) == 0:
attempts = 0
while max_attempts is None or attempts < max_attempts:
spanning_tree = spanning_tree_fn(graph)
h = PopulatedGraph(spanning_tree, populations, pop_target, epsilon)
possible_cuts = balance_edge_fn(h, choice=choice)

repeat = repeat_until_valid
return possible_cuts
attempts += 1

if not (repeat and len(possible_cuts) == 0):
return possible_cuts

raise RuntimeError(f"Could not find a possible cut after {max_attempts} attempts.")


def bipartition_tree_random(
Expand Down Expand Up @@ -290,7 +308,13 @@ def bipartition_tree_random(


def recursive_tree_part(
graph, parts, pop_target, pop_col, epsilon, node_repeats=1, method=bipartition_tree
graph,
parts,
pop_target,
pop_col,
epsilon,
node_repeats=1,
method=partial(bipartition_tree, max_attempts=10000)
):
"""Uses :func:`~gerrychain.tree.bipartition_tree` recursively to partition a tree into
``len(parts)`` parts of population ``pop_target`` (within ``epsilon``). Can be used to
Expand Down Expand Up @@ -354,7 +378,7 @@ def get_seed_chunks(
pop_col,
epsilon,
node_repeats=1,
method=bipartition_tree_random
method=partial(bipartition_tree_random, max_attempts=10000)
):
"""
Helper function for recursive_seed_part. Partitions the graph into ``num_chunks`` chunks,
Expand Down Expand Up @@ -474,7 +498,7 @@ def recursive_seed_part_inner(
pop_target,
pop_col,
epsilon,
method=bipartition_tree,
method=partial(bipartition_tree, max_attempts=10000),
node_repeats=1,
n=None,
ceil=None,
Expand Down Expand Up @@ -585,7 +609,7 @@ def recursive_seed_part(
pop_target,
pop_col,
epsilon,
method=bipartition_tree,
method=partial(bipartition_tree, max_attempts=10000),
node_repeats=1,
n=None,
ceil=None
Expand Down Expand Up @@ -621,7 +645,7 @@ def recursive_seed_part(
pop_target,
pop_col,
epsilon,
method=bipartition_tree,
method=partial(bipartition_tree, max_attempts=10000),
node_repeats=node_repeats,
n=n,
ceil=ceil
Expand Down

0 comments on commit 005d77a

Please sign in to comment.