Skip to content

Commit 664ba99

Browse files
Add seed parameter.
This allows to get the exact same results in one process.
1 parent db45ac6 commit 664ba99

File tree

2 files changed

+8
-1
lines changed

2 files changed

+8
-1
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@ edges_kv.save_word2vec_format(EDGES_EMBEDDING_FILENAME)
7373
Use these keys exactly. If not set, will use the global ones which were passed on the object initialization`
7474
10. `quiet`: Boolean controlling the verbosity. (default: False)
7575
11. `temp_folder`: String path pointing to folder to save a shared memory copy of the graph - Supply when working on graphs that are too big to fit in memory during algorithm execution.
76+
12. `seed`: Seed for the random number generator.
7677

7778
- `Node2Vec.fit` method:
7879
Accepts any key word argument acceptable by gensim.Word2Vec

node2vec/node2vec.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import random
12
import os
23
from collections import defaultdict
34

@@ -22,7 +23,7 @@ class Node2Vec:
2223

2324
def __init__(self, graph: nx.Graph, dimensions: int = 128, walk_length: int = 80, num_walks: int = 10, p: float = 1,
2425
q: float = 1, weight_key: str = 'weight', workers: int = 1, sampling_strategy: dict = None,
25-
quiet: bool = False, temp_folder: str = None):
26+
quiet: bool = False, temp_folder: str = None, seed: int = None):
2627
"""
2728
Initiates the Node2Vec object, precomputes walking probabilities and generates the walks.
2829
@@ -35,6 +36,7 @@ def __init__(self, graph: nx.Graph, dimensions: int = 128, walk_length: int = 80
3536
:param weight_key: On weighted graphs, this is the key for the weight attribute (default: 'weight')
3637
:param workers: Number of workers for parallel execution (default: 1)
3738
:param sampling_strategy: Node specific sampling strategies, supports setting node specific 'q', 'p', 'num_walks' and 'walk_length'.
39+
:param seed: Seed for the random number generator.
3840
Use these keys exactly. If not set, will use the global ones which were passed on the object initialization
3941
:param temp_folder: Path to folder with enough space to hold the memory map of self.d_graph (for big graphs); to be passed joblib.Parallel.temp_folder
4042
"""
@@ -63,6 +65,10 @@ def __init__(self, graph: nx.Graph, dimensions: int = 128, walk_length: int = 80
6365
self.temp_folder = temp_folder
6466
self.require = "sharedmem"
6567

68+
if seed is not None:
69+
random.seed(seed)
70+
np.random.seed(seed)
71+
6672
self._precompute_probabilities()
6773
self.walks = self._generate_walks()
6874

0 commit comments

Comments
 (0)