Skip to content

Commit

Permalink
Adapt to new TopoNetX import convention
Browse files Browse the repository at this point in the history
  • Loading branch information
ffl096 committed Jul 8, 2024
1 parent 04d9437 commit 558532e
Show file tree
Hide file tree
Showing 28 changed files with 230 additions and 1,093 deletions.
5 changes: 3 additions & 2 deletions test/nn/simplicial/test_dist2cycle.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
"""Unit tests for Dist2Cycke Model."""

import numpy as np
import toponetx as tnx
import torch
from toponetx.classes import SimplicialComplex

from topomodelx.nn.simplicial.dist2cycle import Dist2Cycle

Expand All @@ -15,7 +16,7 @@ def test_forward(self):
face_set = [[2, 3, 4], [2, 4, 5]]

torch.manual_seed(42)
simplicial_complex = SimplicialComplex(edge_set + face_set)
simplicial_complex = tnx.SimplicialComplex(edge_set + face_set)
laplacian_down_1 = simplicial_complex.down_laplacian_matrix(rank=1).todense()
adjacency_1 = simplicial_complex.adjacency_matrix(rank=1).todense()
laplacian_down_1_inv = np.linalg.pinv(
Expand Down
5 changes: 3 additions & 2 deletions test/nn/simplicial/test_hsn.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
"""Unit tests for HSN Model."""

import numpy as np
import toponetx as tnx
import torch
from toponetx.classes import SimplicialComplex

from topomodelx.nn.simplicial.hsn import HSN

Expand All @@ -15,7 +16,7 @@ def test_forward(self):
face_set = [[2, 3, 4], [2, 4, 5]]

torch.manual_seed(42)
simplicial_complex = SimplicialComplex(edge_set + face_set)
simplicial_complex = tnx.SimplicialComplex(edge_set + face_set)
laplacian_down_1 = simplicial_complex.down_laplacian_matrix(rank=1).todense()
adjacency_1 = simplicial_complex.adjacency_matrix(rank=1).todense()
laplacian_down_1_inv = np.linalg.pinv(
Expand Down
5 changes: 3 additions & 2 deletions test/nn/simplicial/test_san.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
"""Unit tests for SAN Model."""

import itertools
import random

import toponetx as tnx
import torch
from toponetx.classes import SimplicialComplex

from topomodelx.nn.simplicial.san import SAN
from topomodelx.utils.sparse import from_sparse
Expand All @@ -28,7 +29,7 @@ def test_forward(self):
)
random.shuffle(all_combinations)
selected_combinations = all_combinations[:faces]
simplicial_complex = SimplicialComplex()
simplicial_complex = tnx.SimplicialComplex()
for simplex in selected_combinations:
simplicial_complex.add_simplex(simplex)
x_1 = torch.randn(35, 2)
Expand Down
5 changes: 3 additions & 2 deletions test/nn/simplicial/test_sca_cmps.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
"""Unit tests for SCA Model."""

import itertools
import random

import toponetx as tnx
import torch
from toponetx.classes import SimplicialComplex

from topomodelx.nn.simplicial.sca_cmps import SCACMPS
from topomodelx.utils.sparse import from_sparse
Expand All @@ -28,7 +29,7 @@ def test_forward(self):
)
random.shuffle(all_combinations)
selected_combinations = all_combinations[:faces]
simplicial_complex = SimplicialComplex()
simplicial_complex = tnx.SimplicialComplex()
for simplex in selected_combinations:
simplicial_complex.add_simplex(simplex)

Expand Down
5 changes: 3 additions & 2 deletions test/nn/simplicial/test_sccn.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
"""Unit tests for SCCN Model."""

import itertools
import random

import toponetx as tnx
import torch
from toponetx.classes import SimplicialComplex

from topomodelx.nn.simplicial.sccn import SCCN
from topomodelx.utils.sparse import from_sparse
Expand All @@ -28,7 +29,7 @@ def test_forward(self):
)
random.shuffle(all_combinations)
selected_combinations = all_combinations[:faces]
simplicial_complex = SimplicialComplex()
simplicial_complex = tnx.SimplicialComplex()
for simplex in selected_combinations:
simplicial_complex.add_simplex(simplex)

Expand Down
5 changes: 3 additions & 2 deletions test/nn/simplicial/test_sccnn.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
"""Unit tests for SCCNN Model."""

import itertools
import random

import toponetx as tnx
import torch
from toponetx.classes import SimplicialComplex

from topomodelx.nn.simplicial.sccnn import SCCNN
from topomodelx.utils.sparse import from_sparse
Expand All @@ -28,7 +29,7 @@ def test_forward(self):
)
random.shuffle(all_combinations)
selected_combinations = all_combinations[:faces]
simplicial_complex = SimplicialComplex()
simplicial_complex = tnx.SimplicialComplex()
for simplex in selected_combinations:
simplicial_complex.add_simplex(simplex)
# Some nodes might not be selected at all in the combinations above
Expand Down
5 changes: 3 additions & 2 deletions test/nn/simplicial/test_scconv.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
"""Unit tests for SCCNN Model."""

import itertools
import random

import toponetx as tnx
import torch
from toponetx.classes import SimplicialComplex

from topomodelx.nn.simplicial.scconv import SCConv
from topomodelx.utils.sparse import from_sparse
Expand All @@ -28,7 +29,7 @@ def test_forward(self):
)
random.shuffle(all_combinations)
selected_combinations = all_combinations[:faces]
simplicial_complex = SimplicialComplex()
simplicial_complex = tnx.SimplicialComplex()
for simplex in selected_combinations:
simplicial_complex.add_simplex(simplex)
# Some nodes might not be selected at all in the combinations above
Expand Down
5 changes: 3 additions & 2 deletions test/nn/simplicial/test_scn2.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
"""Unit tests for SCN2 Model."""

import itertools
import random

import toponetx as tnx
import torch
from toponetx.classes import SimplicialComplex

from topomodelx.nn.simplicial.scn2 import SCN2
from topomodelx.utils.sparse import from_sparse
Expand All @@ -28,7 +29,7 @@ def test_forward(self):
)
random.shuffle(all_combinations)
selected_combinations = all_combinations[:faces]
simplicial_complex = SimplicialComplex()
simplicial_complex = tnx.SimplicialComplex()
for simplex in selected_combinations:
simplicial_complex.add_simplex(simplex)

Expand Down
5 changes: 3 additions & 2 deletions test/nn/simplicial/test_scnn.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
"""Unit tests for SCNN Model."""

import itertools
import random

import numpy as np
import toponetx as tnx
import torch
from toponetx.classes import SimplicialComplex

from topomodelx.nn.simplicial.scnn import SCNN
from topomodelx.utils.sparse import from_sparse
Expand All @@ -29,7 +30,7 @@ def test_forward(self):
)
random.shuffle(all_combinations)
selected_combinations = all_combinations[:faces]
simplicial_complex = SimplicialComplex()
simplicial_complex = tnx.SimplicialComplex()
for simplex in selected_combinations:
simplicial_complex.add_simplex(simplex)
x_1 = torch.randn(simplicial_complex.shape[1], 2)
Expand Down
13 changes: 8 additions & 5 deletions topomodelx/nn/simplicial/scone.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
"""Neural network implementation of classification using SCoNe."""

import random
from itertools import product

import networkx as nx
import numpy as np
import toponetx as tnx
import torch
from scipy.spatial import Delaunay, distance
from toponetx.classes.simplicial_complex import SimplicialComplex
from torch import nn
from torch.utils.data.dataset import Dataset

Expand All @@ -15,7 +16,7 @@

def generate_complex(
N: int = 100, *, rng: np.random.Generator | None = None
) -> tuple[SimplicialComplex, np.ndarray]:
) -> tuple[tnx.SimplicialComplex, np.ndarray]:
"""Generate a simplicial complex as described.
Generate a simplicial complex of dimension 2 as follows:
Expand Down Expand Up @@ -58,13 +59,13 @@ def generate_complex(
for j in range(3):
simplices[i][j] = idx_dict[simplices[i][j]]

sc = SimplicialComplex(simplices)
sc = tnx.SimplicialComplex(simplices)
coords = points[list(indices_included)]
return sc, coords


def generate_trajectories(
sc: SimplicialComplex, coords: np.ndarray, n_max: int = 1000
sc: tnx.SimplicialComplex, coords: np.ndarray, n_max: int = 1000
) -> list[list[int]]:
"""Generate trajectories from nodes in the lower left corner to the upper right corner connected through a node in the middle."""
# Get indices for start points in the lower left corner, mid points in the center region and end points in the upper right corner.
Expand Down Expand Up @@ -98,7 +99,9 @@ def generate_trajectories(
class TrajectoriesDataset(Dataset):
"""Create a dataset of trajectories."""

def __init__(self, sc: SimplicialComplex, trajectories: list[list[int]]) -> None:
def __init__(
self, sc: tnx.SimplicialComplex, trajectories: list[list[int]]
) -> None:
self.trajectories = trajectories
self.sc = sc
self.adjacency = torch.Tensor(sc.adjacency_matrix(0).toarray())
Expand Down
Loading

0 comments on commit 558532e

Please sign in to comment.