Skip to content

Commit a3e4619

Browse files
committed
Add lattice
1 parent 1d7dc6b commit a3e4619

File tree

1 file changed

+45
-3
lines changed

1 file changed

+45
-3
lines changed

cyaron/graph.py

Lines changed: 45 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
1-
from .utils import *
2-
from .vector import Vector
1+
import math
32
import random
4-
from typing import TypeVar, Callable
3+
import itertools
4+
5+
from .utils import *
6+
from typing import TypeVar, Callable, Tuple, Union, List, Iterable, cast
57

68
__all__ = ["Edge", "Graph"]
79

@@ -326,6 +328,46 @@ def graph(point_count, edge_count, **kwargs):
326328
i += 1
327329
return graph
328330

331+
@staticmethod
332+
def lattice(dim: Union[List[int], Tuple[int, ...]],
333+
nei: int = 1,
334+
directed: bool = False,
335+
mutual: bool = True,
336+
circular: Union[bool, Iterable[bool]] = True):
337+
g = Graph(math.prod(dim), directed)
338+
339+
num = len(dim)
340+
try:
341+
circular = iter(cast(Iterable[bool], circular))
342+
circular = itertools.chain(circular, itertools.repeat(True))
343+
circular = itertools.islice(circular, num)
344+
except TypeError:
345+
circular = itertools.repeat(cast(bool, circular), num)
346+
circular = list(circular)
347+
348+
pre_prod = [1] + list(dim[0:-1])
349+
for i in range(1, num):
350+
pre_prod[i] *= pre_prod[i - 1]
351+
352+
for d in itertools.product(*map(lambda c: range(c), dim)):
353+
u = math.sumprod(d, pre_prod)
354+
for i, cir in zip(range(num), circular):
355+
flag, v = d[i], u
356+
for _ in range(nei):
357+
flag += 1
358+
v += pre_prod[i]
359+
if flag == dim[i]:
360+
if cir and dim[i] > 2:
361+
v -= pre_prod[i] * flag
362+
flag = 0
363+
else:
364+
break
365+
g.add_edge(u + 1, v + 1)
366+
if directed and mutual:
367+
g.add_edge(v + 1, u + 1)
368+
369+
return g
370+
329371
@staticmethod
330372
def DAG(point_count, edge_count, **kwargs):
331373
"""DAG(point_count, edge_count, **kwargs) -> Graph

0 commit comments

Comments
 (0)