Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 9 additions & 9 deletions node2vec/node2vec.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,6 @@ def _precompute_probabilities(self):
"""

d_graph = self.d_graph
first_travel_done = set()

nodes_generator = self.graph.nodes() if self.quiet \
else tqdm(self.graph.nodes(), desc='Computing transition probabilities')
Expand All @@ -90,7 +89,6 @@ def _precompute_probabilities(self):
d_graph[current_node][self.PROBABILITIES_KEY] = dict()

unnormalized_weights = list()
first_travel_weights = list()
d_neighbors = list()

# Calculate unnormalized weights
Expand All @@ -110,23 +108,25 @@ def _precompute_probabilities(self):

# Assign the unnormalized sampling strategy weight, normalize during random walk
unnormalized_weights.append(ss_weight)
if current_node not in first_travel_done:
first_travel_weights.append(self.graph[current_node][destination].get(self.weight_key, 1))
d_neighbors.append(destination)

# Normalize
unnormalized_weights = np.array(unnormalized_weights)
d_graph[current_node][self.PROBABILITIES_KEY][
source] = unnormalized_weights / unnormalized_weights.sum()

if current_node not in first_travel_done:
unnormalized_weights = np.array(first_travel_weights)
d_graph[current_node][self.FIRST_TRAVEL_KEY] = unnormalized_weights / unnormalized_weights.sum()
first_travel_done.add(current_node)

# Save neighbors
d_graph[current_node][self.NEIGHBORS_KEY] = d_neighbors

# Calculate first_travel weights for source
first_travel_weights = []

for destination in self.graph.neighbors(source):
first_travel_weights.append(self.graph[source][destination].get(self.weight_key, 1))

first_travel_weights = np.array(first_travel_weights)
d_graph[source][self.FIRST_TRAVEL_KEY] = first_travel_weights / first_travel_weights.sum()

def _generate_walks(self) -> list:
"""
Generates the random walks which will be used as the skip-gram input.
Expand Down