Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
233 changes: 206 additions & 27 deletions hypernetx/drawing/rubber_band.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@
import numpy as np
from scipy.spatial.distance import pdist
from scipy.spatial import ConvexHull
from shapely.geometry import Polygon, Point, LineString
from shapely.ops import unary_union
import alphashape

# increases the default figure size to 8in square.
plt.rcParams["figure.figsize"] = (8, 8)
Expand Down Expand Up @@ -348,6 +351,127 @@ def draw_hyper_labels(H, pos, node_radius={}, ax=None, labels={}, **kwargs):
)


def get_concave_hull(points, alpha=0.5):
"""
Generate a concave hull around points using alpha shapes.
Lower alpha values create more concave/detailed shapes.
"""
if len(points) < 4: # Handle small point sets
return Polygon(points)

try:
hull = alphashape.alphashape(points, alpha)
return hull if hull.is_valid else Polygon(points)
except Exception:
# Fallback to convex hull if alpha shape fails
return Polygon(points[ConvexHull(points).vertices])


def create_rubber_band(nodes_pos, edge_nodes, buffer_distance=0.1, alpha=0.5, offset_scale=0.0):
"""
Create a rubber band shape around nodes that strictly contains only the nodes in the hyperedge.

Args:
nodes_pos: Dict of node positions {node: (x, y)}
edge_nodes: List of nodes in the hyperedge
buffer_distance: Base distance to expand the shape around nodes
alpha: Controls the concaveness (lower = more concave)
offset_scale: Additional scaling factor for the buffer
"""
# Get positions of nodes in the edge
edge_points = np.array([nodes_pos[n] for n in edge_nodes])

# Calculate adaptive buffer based on node distances
if len(edge_points) > 1:
# Calculate pairwise distances between nodes in the edge
distances = pdist(edge_points)
# Use mean distance between nodes to scale the buffer
mean_distance = np.mean(distances)
# Scale buffer_distance by mean distance (using 10% of mean distance as default)
adaptive_buffer = buffer_distance * mean_distance
else:
# For single nodes, use the base buffer_distance
adaptive_buffer = buffer_distance

if len(edge_points) < 3:
# Handle special cases for 1 or 2 nodes
if len(edge_points) == 1:
point = Point(edge_points[0])
return point.buffer(adaptive_buffer)
else:
line = LineString(edge_points)
return line.buffer(adaptive_buffer)

# Create initial concave hull
hull = get_concave_hull(edge_points, alpha)

# Expand using adaptive buffer
hull = hull.buffer(adaptive_buffer)

# Verify and adjust if any non-edge nodes are included
non_edge_nodes = set(nodes_pos.keys()) - set(edge_nodes)
included_points = []
to_include = []

# Handle nodes in the edge
for node in edge_nodes:
point = Point(nodes_pos[node])
# Scale random buffer by adaptive_buffer
random_buffer = adaptive_buffer * (1 + np.random.uniform(0, 0.7))
to_include.append(point.buffer(random_buffer))

# Handle nodes not in the edge
for node in non_edge_nodes:
point = Point(nodes_pos[node])
if hull.contains(point):
# Scale random buffer by adaptive_buffer
random_buffer = adaptive_buffer * (1 + np.random.uniform(0, 0.7))
included_points.append(point.buffer(random_buffer))

if included_points:
# Subtract any areas containing non-edge nodes
include_area = unary_union(to_include)
hull = hull.union(include_area)
excluded_area = unary_union(included_points)
hull = hull.difference(excluded_area)

return hull


def draw_hypergraph(H, pos=None, ax=None, **kwargs):
"""
Draw the hypergraph using improved rubber band visualization.

Args:
H: Hypergraph
pos: Node positions (will be generated if None)
ax: Matplotlib axis
**kwargs: Additional drawing parameters
"""
if ax is None:
ax = plt.gca()

if pos is None:
# Generate layout (you can use your preferred layout algorithm)
pos = generate_layout(H)

# Draw edges
for edge in H.edges():
hull = create_rubber_band(pos, H.edges[edge],
buffer_distance=kwargs.get('buffer_distance', 0.1),
alpha=kwargs.get('alpha', 0.5))

# Convert hull to matplotlib path and plot
x, y = hull.exterior.xy
ax.fill(x, y, alpha=0.3)

# Draw nodes
for node in H.nodes():
ax.plot(*pos[node], 'ko')

return ax


def draw(
H,
pos=None,
Expand All @@ -373,6 +497,7 @@ def draw(
contain_hyper_edges=False,
additional_edges_kwargs={},
return_pos=False,
convex=True,
):
"""
Draw a hypergraph as a Matplotlib figure
Expand Down Expand Up @@ -457,14 +582,18 @@ def draw(
...
contain_hyper_edges: bool
whether the rubber band shoudl be drawn around the location of the edge in the bipartite graph. This may be invisibile unless "with_additional_edges" contains this information.
convex: bool
if True, uses convex hull visualization; if False, uses rubber band visualization

"""

ax = ax or plt.gca()
polys = None

if pos is None:
pos = layout_node_link(H, with_additional_edges, layout=layout, **layout_kwargs)

# Calculate node radius regardless of visualization method
r0 = get_default_radius(H, pos)
a0 = np.pi * r0**2

Expand All @@ -475,23 +604,54 @@ def get_node_radius(v):
return node_radius.get(v, 1) * r0
return node_radius * r0

# guarantee that node radius is a dictionary mapping nodes to values
node_radius = {v: get_node_radius(v) for v in H.nodes}

# for convenience, we are using setdefault to mutate the argument
# however, we need to copy this to prevent side-effects
edges_kwargs = edges_kwargs.copy()
edges_kwargs.setdefault("edgecolors", plt.cm.tab10(np.arange(len(H.edges)) % 10))
edges_kwargs.setdefault("facecolors", "none")

polys = draw_hyper_edges(
H,
pos,
node_radius=node_radius,
ax=ax,
contain_hyper_edges=contain_hyper_edges,
**edges_kwargs
)
# Setup colors
if with_color:
colors = plt.cm.tab10(np.arange(len(H.edges)) % 10)
else:
colors = ['gray'] * len(H.edges)

if not convex:
# Use rubber band visualization
for idx, edge in enumerate(H.edges()):
hull = create_rubber_band(pos, H.edges[edge],
buffer_distance=edges_kwargs.get('buffer_distance', 0.1),
alpha=edges_kwargs.get('alpha', 0.5))

color = colors[idx]
if isinstance(color, np.ndarray):
color = tuple(color)

# Handle both Polygon and MultiPolygon cases
if hull.geom_type == 'MultiPolygon':
for polygon in hull.geoms:
x, y = polygon.exterior.xy
ax.plot(x, y, color=color,
linewidth=edges_kwargs.get('linewidth', 1),
linestyle=edges_kwargs.get('linestyle', '-'),
alpha=edges_kwargs.get('alpha', 1.0))
else:
x, y = hull.exterior.xy
ax.plot(x, y, color=color,
linewidth=edges_kwargs.get('linewidth', 1),
linestyle=edges_kwargs.get('linestyle', '-'),
alpha=edges_kwargs.get('alpha', 1.0))
else:
# Original convex hull visualization code
edges_kwargs = edges_kwargs.copy()
if with_color:
edges_kwargs.setdefault("edgecolors", colors)
edges_kwargs.setdefault("facecolors", "none")

polys = draw_hyper_edges(
H,
pos,
node_radius=node_radius,
ax=ax,
contain_hyper_edges=contain_hyper_edges,
**edges_kwargs
)

if with_additional_edges:
nx.draw_networkx_edges(
Expand All @@ -505,18 +665,37 @@ def get_node_radius(v):
labels = get_frozenset_label(
H.edges, count=with_edge_counts, override=edge_labels
)

draw_hyper_edge_labels(
H,
pos,
polys,
color=edges_kwargs["edgecolors"],
backgroundcolor=(1, 1, 1, edge_label_alpha),
labels=labels,
ax=ax,
edge_labels_on_edge=edge_labels_on_edge,
**edge_labels_kwargs
)

if convex:
# Original edge label drawing for convex hull
draw_hyper_edge_labels(
H,
pos,
polys,
color=colors,
backgroundcolor=(1, 1, 1, edge_label_alpha),
labels=labels,
ax=ax,
edge_labels_on_edge=edge_labels_on_edge,
**edge_labels_kwargs
)
else:
# Draw edge labels for rubber band visualization
for idx, edge in enumerate(H.edges()):
label = labels.get(edge, edge)
center = np.mean([pos[n] for n in H.edges[edge]], axis=0)
color = colors[idx]
if isinstance(color, np.ndarray):
color = tuple(color)
ax.annotate(
label,
center,
color=color,
backgroundcolor=(1, 1, 1, edge_label_alpha),
ha='center',
va='center',
**edge_labels_kwargs
)

if with_node_labels:
labels = get_frozenset_label(
Expand Down