Skip to content

Commit d90072f

Browse files
authored
fix some commas and indexes (labmlai#106)
* fix some commas and indexes * equation in GATv2 edited according to the paper on arXiv
1 parent 62c5786 commit d90072f

File tree

2 files changed

+6
-6
lines changed

2 files changed

+6
-6
lines changed

labml_nn/graphs/gat/__init__.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -140,10 +140,10 @@ def forward(self, h: torch.Tensor, adj_mat: torch.Tensor):
140140
g_repeat_interleave = g.repeat_interleave(n_nodes, dim=0)
141141
# Now we concatenate to get
142142
# $$\{\overrightarrow{g_1} \Vert \overrightarrow{g_1},
143-
# \overrightarrow{g_1}, \Vert \overrightarrow{g_2},
143+
# \overrightarrow{g_1} \Vert \overrightarrow{g_2},
144144
# \dots, \overrightarrow{g_1} \Vert \overrightarrow{g_N},
145145
# \overrightarrow{g_2} \Vert \overrightarrow{g_1},
146-
# \overrightarrow{g_2}, \Vert \overrightarrow{g_2},
146+
# \overrightarrow{g_2} \Vert \overrightarrow{g_2},
147147
# \dots, \overrightarrow{g_2} \Vert \overrightarrow{g_N}, ...\}$$
148148
g_concat = torch.cat([g_repeat_interleave, g_repeat], dim=-1)
149149
# Reshape so that `g_concat[i, j]` is $\overrightarrow{g_i} \Vert \overrightarrow{g_j}$
@@ -170,7 +170,7 @@ def forward(self, h: torch.Tensor, adj_mat: torch.Tensor):
170170

171171
# We then normalize attention scores (or coefficients)
172172
# $$\alpha_{ij} = \text{softmax}_j(e_{ij}) =
173-
# \frac{\exp(e_{ij})}{\sum_{j \in \mathcal{N}_i} \exp(e_{ij})}$$
173+
# \frac{\exp(e_{ij})}{\sum_{k \in \mathcal{N}_i} \exp(e_{ik})}$$
174174
#
175175
# where $\mathcal{N}_i$ is the set of nodes connected to $i$.
176176
#

labml_nn/graphs/gatv2/__init__.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -184,10 +184,10 @@ def forward(self, h: torch.Tensor, adj_mat: torch.Tensor):
184184
g_r_repeat_interleave = g_r.repeat_interleave(n_nodes, dim=0)
185185
# Now we add the two tensors to get
186186
# $$\{\overrightarrow{{g_l}_1} + \overrightarrow{{g_r}_1},
187-
# \overrightarrow{{g_l}_1}, + \overrightarrow{{g_r}_2},
187+
# \overrightarrow{{g_l}_1} + \overrightarrow{{g_r}_2},
188188
# \dots, \overrightarrow{{g_l}_1} +\overrightarrow{{g_r}_N},
189189
# \overrightarrow{{g_l}_2} + \overrightarrow{{g_r}_1},
190-
# \overrightarrow{{g_l}_2}, + \overrightarrow{{g_r}_2},
190+
# \overrightarrow{{g_l}_2} + \overrightarrow{{g_r}_2},
191191
# \dots, \overrightarrow{{g_l}_2} + \overrightarrow{{g_r}_N}, ...\}$$
192192
g_sum = g_l_repeat + g_r_repeat_interleave
193193
# Reshape so that `g_sum[i, j]` is $\overrightarrow{{g_l}_i} + \overrightarrow{{g_r}_j}$
@@ -214,7 +214,7 @@ def forward(self, h: torch.Tensor, adj_mat: torch.Tensor):
214214

215215
# We then normalize attention scores (or coefficients)
216216
# $$\alpha_{ij} = \text{softmax}_j(e_{ij}) =
217-
# \frac{\exp(e_{ij})}{\sum_{j \in \mathcal{N}_i} \exp(e_{ij})}$$
217+
# \frac{\exp(e_{ij})}{\sum_{j' \in \mathcal{N}_i} \exp(e_{ij'})}$$
218218
#
219219
# where $\mathcal{N}_i$ is the set of nodes connected to $i$.
220220
#

0 commit comments

Comments (0)