Skip to content

Commit

Permalink
edge space takes edges of node pairs as inputs, instead of indices
Browse files Browse the repository at this point in the history
  • Loading branch information
cookbook-ms committed Oct 21, 2024
1 parent 6dc8334 commit c2de819
Show file tree
Hide file tree
Showing 3 changed files with 147 additions and 61 deletions.
51 changes: 36 additions & 15 deletions geometric_kernels/spaces/graph_edge.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,13 @@ class GraphEdge(DiscreteSpectrumSpace):
"""

def __init__(self, G, triangle_list=None, sc_lifting=False): # type: ignore
self.G = G
self.G = nx.Graph()
self.cache: Dict[int, Tuple[B.Numeric, B.Numeric]] = {}
self.incidence_matrix = nx.incidence_matrix(G, oriented=True).toarray() # "obtain the oriented incidence matrix"
# reorder the edges in the graph based on the order of the nodes
self.nodes = list(G.nodes)
self.edges = [(min(u, v), max(u, v)) for u, v in G.edges]
self.G.add_edges_from(self.edges)
self.incidence_matrix = nx.incidence_matrix(self.G, oriented=True).toarray() # "obtain the oriented incidence matrix"
if sc_lifting is not False:
if triangle_list is None:
print("No list of triangles is provided, we consider all triangles in the graph as 2-simplices.")
Expand Down Expand Up @@ -110,8 +114,8 @@ def dimension(self) -> int:
def sc_simplices(self):
"""return the nodes, edges and triangles in the graph"""
print('----Simplicial 2-complex summary---')
print('nodes: ', list(self.G.nodes))
print('edges: ', list(self.G.edges))
print('nodes: ', self.nodes)
print('edges: ', self.edges)
print('triangles: ', self.triangles)
return None

Expand Down Expand Up @@ -150,6 +154,21 @@ def edge_laplacian(self):
else:
return self._edge_laplacian, self._down_edge_laplacian, self._up_edge_laplacian

def get_edge_index(self, edges):
"""
Get the indices of some provided edges in the edge list.
Args:
edges (list): Edges.
Returns:
list: Indices of the edges.
"""
assert isinstance(edges, list) # "The edges should be a list."
# each edge should be in the edge list
assert all(edge in self.edges for edge in edges)
return B.to_numpy([self.edges.index(edge) for edge in edges])

def triangles_all_clique(self) -> list:
"""
Get a list of triangles in the graph.
Expand Down Expand Up @@ -181,23 +200,23 @@ def triangles_to_B2(self) -> np.ndarray:
Returns:
np.ndarray: B2 matrix.
"""
edges = list(self.G.edges)

triangles = self.triangles
B2 = np.zeros((len(edges), len(triangles)))
B2 = np.zeros((len(self.edges), len(triangles)))
for j, triangle in enumerate(triangles):
a, b, c = triangle
try:
index_a = edges.index((a, b))
index_a = self.edges.index((a, b))
except ValueError:
index_a = edges.index((b, a))
index_a = self.edges.index((b, a))
try:
index_b = edges.index((b, c))
index_b = self.edges.index((b, c))
except ValueError:
index_b = edges.index((c, b))
index_b = self.edges.index((c, b))
try:
index_c = edges.index((a, c))
index_c = self.edges.index((a, c))
except ValueError:
index_c = edges.index((c, a))
index_c = self.edges.index((c, a))

B2[index_a, j] = 1
B2[index_c, j] = -1
Expand Down Expand Up @@ -363,11 +382,13 @@ def get_repeated_eigenvalues(self, num: int) -> B.Numeric:

def random(self, key, number):
num_edges = B.shape(self._edge_laplacian)[0]
key, random_edges = B.randint(
key, random_edges_idx = B.randint(
key, dtype_integer(key), number, 1, lower=0, upper=num_edges,
)

return key, random_edges

random_edges = [self.edges[i] for i in random_edges_idx.flatten().tolist()]

return key, random_edges, random_edges_idx

@property
def element_shape(self):
Expand Down
65 changes: 48 additions & 17 deletions notebooks/backends/JAX_SimplicialComplex.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -294,7 +294,7 @@
"text": [
"----Simplicial 2-complex summary---\n",
"nodes: [1, 2, 3, 5, 4]\n",
"edges: [(1, 2), (1, 3), (1, 5), (2, 3), (3, 4), (5, 4)]\n",
"edges: [(1, 2), (1, 3), (1, 5), (2, 3), (3, 4), (4, 5)]\n",
"triangles: [(1, 2, 3)]\n"
]
}
Expand All @@ -313,7 +313,7 @@
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": 9,
"metadata": {},
"outputs": [
{
Expand All @@ -327,7 +327,7 @@
" [ 0.]])"
]
},
"execution_count": 10,
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
Expand All @@ -338,7 +338,7 @@
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": 10,
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -375,7 +375,7 @@
},
{
"cell_type": "code",
"execution_count": 12,
"execution_count": 11,
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -448,7 +448,7 @@
},
{
"cell_type": "code",
"execution_count": 13,
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -465,7 +465,7 @@
},
{
"cell_type": "code",
"execution_count": 14,
"execution_count": 13,
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -527,7 +527,7 @@
},
{
"cell_type": "code",
"execution_count": 15,
"execution_count": 14,
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -556,7 +556,7 @@
},
{
"cell_type": "code",
"execution_count": 16,
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -595,7 +595,7 @@
},
{
"cell_type": "code",
"execution_count": 17,
"execution_count": 16,
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -644,6 +644,35 @@
"An explicit `key` parameter is needed to support JAX as one of the backends."
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[(1, 2), (1, 3), (1, 5)]\n"
]
}
],
"source": [
"key = PRNGKey(1234)\n",
"\n",
"key, random_sampled_edges, xs = sc.random(key, 3)\n",
"print(random_sampled_edges)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"If users prefer input the edges directly, we would need to obtain the indices of the edges in the edge space of the simplicial complex. This can be done by calling `sc.get_edge_index(input_edges)`.\n",
"\n",
"Below we verify that the indices of the sampled edges are correct."
]
},
{
"cell_type": "code",
"execution_count": 18,
Expand All @@ -653,18 +682,20 @@
"name": "stdout",
"output_type": "stream",
"text": [
"[[0]\n",
" [1]\n",
" [2]]\n",
"[[0]\n",
" [1]\n",
" [2]]\n"
]
}
],
"source": [
"key = PRNGKey(1234)\n",
"\n",
"key, xs = sc.random(key, 3)\n",
"\n",
"print(xs)"
"print(xs)\n",
"input_indices = sc.get_edge_index(random_sampled_edges)\n",
"input_indices = jnp.array(input_indices).reshape(-1, 1)\n",
"print(input_indices)"
]
},
{
Expand Down Expand Up @@ -785,8 +816,8 @@
"metadata": {},
"outputs": [],
"source": [
"base_edge = (1,2)\n",
"base_edge_idx = list(G.edges).index(base_edge)\n",
"base_edge = [(1,2)]\n",
"base_edge_idx = sc.get_edge_index(base_edge)[0]\n",
"other_edges = jnp.arange(sc.num_edges)[:, None]\n",
"edges = [e for e in G.edges()]"
]
Expand Down
Loading

0 comments on commit c2de819

Please sign in to comment.