diff --git a/geometric_kernels/spaces/graph_edge.py b/geometric_kernels/spaces/graph_edge.py index 4ebc2125..955f9274 100644 --- a/geometric_kernels/spaces/graph_edge.py +++ b/geometric_kernels/spaces/graph_edge.py @@ -52,9 +52,13 @@ class GraphEdge(DiscreteSpectrumSpace): """ def __init__(self, G, triangle_list=None, sc_lifting=False): # type: ignore - self.G = G + self.G = nx.Graph() self.cache: Dict[int, Tuple[B.Numeric, B.Numeric]] = {} - self.incidence_matrix = nx.incidence_matrix(G, oriented=True).toarray() # "obtain the oriented incidence matrix" + # reorder the edges in the graph based on the order of the nodes + self.nodes = list(G.nodes) + self.edges = [(min(u, v), max(u, v)) for u, v in G.edges] + self.G.add_edges_from(self.edges) + self.incidence_matrix = nx.incidence_matrix(self.G, oriented=True).toarray() # "obtain the oriented incidence matrix" if sc_lifting is not False: if triangle_list is None: print("No list of triangles is provided, we consider all triangles in the graph as 2-simplices.") @@ -110,8 +114,8 @@ def dimension(self) -> int: def sc_simplices(self): """return the nodes, edges and triangles in the graph""" print('----Simplicial 2-complex summary---') - print('nodes: ', list(self.G.nodes)) - print('edges: ', list(self.G.edges)) + print('nodes: ', self.nodes) + print('edges: ', self.edges) print('triangles: ', self.triangles) return None @@ -150,6 +154,21 @@ def edge_laplacian(self): else: return self._edge_laplacian, self._down_edge_laplacian, self._up_edge_laplacian + def get_edge_index(self, edges): + """ + Get the indices of some provided edges in the edge list. + + Args: + edges (list): Edges. + + Returns: + list: Indices of the edges. + """ + assert isinstance(edges, list) # "The edges should be a list." + # each edge should be in the edge list + assert all(edge in self.edges for edge in edges) + return B.to_numpy([self.edges.index(edge) for edge in edges]) + def triangles_all_clique(self) -> list: """ Get a list of triangles in the graph. @@ -181,23 +200,23 @@ def triangles_to_B2(self) -> np.ndarray: Returns: np.ndarray: B2 matrix. """ - edges = list(self.G.edges) + triangles = self.triangles - B2 = np.zeros((len(edges), len(triangles))) + B2 = np.zeros((len(self.edges), len(triangles))) for j, triangle in enumerate(triangles): a, b, c = triangle try: - index_a = edges.index((a, b)) + index_a = self.edges.index((a, b)) except ValueError: - index_a = edges.index((b, a)) + index_a = self.edges.index((b, a)) try: - index_b = edges.index((b, c)) + index_b = self.edges.index((b, c)) except ValueError: - index_b = edges.index((c, b)) + index_b = self.edges.index((c, b)) try: - index_c = edges.index((a, c)) + index_c = self.edges.index((a, c)) except ValueError: - index_c = edges.index((c, a)) + index_c = self.edges.index((c, a)) B2[index_a, j] = 1 B2[index_c, j] = -1 @@ -363,11 +382,13 @@ def get_repeated_eigenvalues(self, num: int) -> B.Numeric: def random(self, key, number): num_edges = B.shape(self._edge_laplacian)[0] - key, random_edges = B.randint( + key, random_edges_idx = B.randint( key, dtype_integer(key), number, 1, lower=0, upper=num_edges, ) - - return key, random_edges + + random_edges = [self.edges[i] for i in random_edges_idx.flatten().tolist()] + + return key, random_edges, random_edges_idx @property def element_shape(self): diff --git a/notebooks/backends/JAX_SimplicialComplex.ipynb b/notebooks/backends/JAX_SimplicialComplex.ipynb index 89b101b8..ad4db439 100644 --- a/notebooks/backends/JAX_SimplicialComplex.ipynb +++ b/notebooks/backends/JAX_SimplicialComplex.ipynb @@ -294,7 +294,7 @@ "text": [ "----Simplicial 2-complex summary---\n", "nodes: [1, 2, 3, 5, 4]\n", - "edges: [(1, 2), (1, 3), (1, 5), (2, 3), (3, 4), (5, 4)]\n", + "edges: [(1, 2), (1, 3), (1, 5), (2, 3), (3, 4), (4, 5)]\n", "triangles: [(1, 2, 3)]\n" ] } @@ -313,7 +313,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 9, "metadata": {}, "outputs": [ { @@ -327,7 +327,7 @@ " [ 0.]])" ] }, - "execution_count": 10, + "execution_count": 9, "metadata": {}, "output_type": "execute_result" } @@ -338,7 +338,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 10, "metadata": {}, "outputs": [ { @@ -375,7 +375,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 11, "metadata": {}, "outputs": [ { @@ -448,7 +448,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 12, "metadata": {}, "outputs": [], "source": [ @@ -465,7 +465,7 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 13, "metadata": {}, "outputs": [ { @@ -527,7 +527,7 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 14, "metadata": {}, "outputs": [ { @@ -556,7 +556,7 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 15, "metadata": {}, "outputs": [], "source": [ @@ -595,7 +595,7 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 16, "metadata": {}, "outputs": [ { @@ -644,6 +644,35 @@ "An explicit `key` parameter is needed to support JAX as one of the backends." ] }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[(1, 2), (1, 3), (1, 5)]\n" + ] + } + ], + "source": [ + "key = PRNGKey(1234)\n", + "\n", + "key, random_sampled_edges, xs = sc.random(key, 3)\n", + "print(random_sampled_edges)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "If users prefer input the edges directly, we would need to obtain the indices of the edges in the edge space of the simplicial complex. This can be done by calling `sc.get_edge_index(input_edges)`.\n", + "\n", + "Below we verify that the indices of the sampled edges are correct." + ] + }, { "cell_type": "code", "execution_count": 18, @@ -653,6 +682,9 @@ "name": "stdout", "output_type": "stream", "text": [ + "[[0]\n", + " [1]\n", + " [2]]\n", "[[0]\n", " [1]\n", " [2]]\n" @@ -660,11 +692,10 @@ } ], "source": [ - "key = PRNGKey(1234)\n", - "\n", - "key, xs = sc.random(key, 3)\n", - "\n", - "print(xs)" + "print(xs)\n", + "input_indices = sc.get_edge_index(random_sampled_edges)\n", + "input_indices = jnp.array(input_indices).reshape(-1, 1)\n", + "print(input_indices)" ] }, { @@ -785,8 +816,8 @@ "metadata": {}, "outputs": [], "source": [ - "base_edge = (1,2)\n", - "base_edge_idx = list(G.edges).index(base_edge)\n", + "base_edge = [(1,2)]\n", + "base_edge_idx = sc.get_edge_index(base_edge)[0]\n", "other_edges = jnp.arange(sc.num_edges)[:, None]\n", "edges = [e for e in G.edges()]" ] diff --git a/notebooks/backends/PyTorch_SimplicialComplex.ipynb b/notebooks/backends/PyTorch_SimplicialComplex.ipynb index 0da45e8b..774a936d 100644 --- a/notebooks/backends/PyTorch_SimplicialComplex.ipynb +++ b/notebooks/backends/PyTorch_SimplicialComplex.ipynb @@ -281,7 +281,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 8, "metadata": {}, "outputs": [ { @@ -290,7 +290,7 @@ "text": [ "----Simplicial 2-complex summary---\n", "nodes: [1, 2, 3, 5, 4]\n", - "edges: [(1, 2), (1, 3), (1, 5), (2, 3), (3, 4), (5, 4)]\n", + "edges: [(1, 2), (1, 3), (1, 5), (2, 3), (3, 4), (4, 5)]\n", "triangles: [(1, 2, 3)]\n" ] } @@ -309,7 +309,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 9, "metadata": {}, "outputs": [ { @@ -323,7 +323,7 @@ " [ 0.]])" ] }, - "execution_count": 10, + "execution_count": 9, "metadata": {}, "output_type": "execute_result" } @@ -334,7 +334,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 10, "metadata": {}, "outputs": [ { @@ -371,7 +371,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 11, "metadata": {}, "outputs": [ { @@ -444,7 +444,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 12, "metadata": {}, "outputs": [], "source": [ @@ -461,7 +461,7 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 13, "metadata": {}, "outputs": [ { @@ -523,7 +523,7 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 14, "metadata": {}, "outputs": [ { @@ -559,7 +559,7 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 15, "metadata": {}, "outputs": [], "source": [ @@ -598,7 +598,7 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 16, "metadata": {}, "outputs": [ { @@ -644,21 +644,21 @@ "metadata": {}, "source": [ "We start by sampling `3` (uniformly) random points on the edge space of our simplicial complex.\n", - "An explicit `key` parameter is needed to support JAX as one of the backends." + "An explicit `key` parameter is needed to support JAX as one of the backends.\n", + "\n", + "**Note:** the here `xs` are the **indicex** of the sampled edges, not the edges themselves." ] }, { "cell_type": "code", - "execution_count": 18, + "execution_count": 17, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "tensor([[4],\n", - " [1],\n", - " [0]], dtype=torch.int32)\n" + "[(3, 4), (1, 3), (1, 2)]\n" ] } ], @@ -666,8 +666,42 @@ "key = torch.Generator()\n", "key.manual_seed(123)\n", "\n", - "key, xs = sc.random(key, 3)\n", - "print(xs)" + "key, random_sampled_edges, xs = sc.random(key, 3)\n", + "print(random_sampled_edges)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "If users prefer input the edges directly, we would need to obtain the indices of the edges in the edge space of the simplicial complex. This can be done by calling `sc.get_edge_index(input_edges)`.\n", + "\n", + "Below we verify that the indices of the sampled edges are correct." + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "tensor([[4],\n", + " [1],\n", + " [0]], dtype=torch.int32)\n", + "tensor([[4],\n", + " [1],\n", + " [0]])\n" + ] + } + ], + "source": [ + "print(xs)\n", + "input_indices = sc.get_edge_index(random_sampled_edges)\n", + "input_indices = torch.tensor(input_indices).reshape(-1, 1)\n", + "print(input_indices)" ] }, { @@ -784,12 +818,12 @@ }, { "cell_type": "code", - "execution_count": 22, + "execution_count": 25, "metadata": {}, "outputs": [], "source": [ - "base_edge = (1,2)\n", - "base_edge_idx = list(G.edges).index(base_edge)\n", + "base_edge = [(1,2)]\n", + "base_edge_idx = sc.get_edge_index(base_edge)[0]\n", "other_edges = torch.arange(sc.num_edges)[:, None]\n", "edges = [e for e in G.edges()]" ] @@ -803,7 +837,7 @@ }, { "cell_type": "code", - "execution_count": 23, + "execution_count": 26, "metadata": {}, "outputs": [], "source": [ @@ -826,7 +860,7 @@ }, { "cell_type": "code", - "execution_count": 24, + "execution_count": 27, "metadata": {}, "outputs": [], "source": [ @@ -850,7 +884,7 @@ }, { "cell_type": "code", - "execution_count": 25, + "execution_count": 28, "metadata": {}, "outputs": [ { @@ -953,7 +987,7 @@ }, { "cell_type": "code", - "execution_count": 26, + "execution_count": 29, "metadata": {}, "outputs": [ { @@ -1008,7 +1042,7 @@ }, { "cell_type": "code", - "execution_count": 27, + "execution_count": 30, "metadata": {}, "outputs": [], "source": [ @@ -1036,7 +1070,7 @@ }, { "cell_type": "code", - "execution_count": 28, + "execution_count": 31, "metadata": {}, "outputs": [ { @@ -1110,7 +1144,7 @@ }, { "cell_type": "code", - "execution_count": 29, + "execution_count": 32, "metadata": {}, "outputs": [ { @@ -1157,7 +1191,7 @@ }, { "cell_type": "code", - "execution_count": 30, + "execution_count": 33, "metadata": {}, "outputs": [ {