Skip to content

Commit 18c4d6e

Browse files
authored
fix out_dim in pytorch GCN example (#231)
fix out_dim in pytorch GCN example
1 parent 4ab29e4 commit 18c4d6e

File tree

1 file changed

+6
-3
lines changed
  • graphlearn/examples/pytorch/gcn

1 file changed

+6
-3
lines changed

graphlearn/examples/pytorch/gcn/gcn.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,12 @@ def __init__(self,
3535
self.depth = depth
3636
self.drop_rate = drop_rate
3737
self.layers = torch.nn.ModuleList()
38+
self.input_dim = input_dim
39+
self.hidden_dim = hidden_dim
40+
self.output_dim = output_dim
3841
for i in range(depth):
39-
input_dim = input_dim if i == 0 else hidden_dim
40-
output_dim = output_dim if i == depth - 1 else hidden_dim
42+
input_dim = self.input_dim if i == 0 else self.hidden_dim
43+
output_dim = self.output_dim if i == self.depth - 1 else self.hidden_dim
4144
self.layers.append(GCNConv(input_dim, output_dim))
4245

4346
def forward(self, data):
@@ -59,4 +62,4 @@ def forward(self, data):
5962

6063
def reset_parameters(self):
6164
for conv in self.layers:
62-
conv.reset_parameters()
65+
conv.reset_parameters()

0 commit comments

Comments
 (0)