Skip to content

Commit

Permalink
fix and lots of tests (dmlc#2650)
Browse files Browse the repository at this point in the history
  • Loading branch information
BarclayII authored Feb 12, 2021
1 parent cf8a3fb commit 9e63010
Show file tree
Hide file tree
Showing 6 changed files with 322 additions and 262 deletions.
2 changes: 1 addition & 1 deletion python/dgl/nn/mxnet/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def bmm_maybe_select(A, B, index):
return B[index, A, :]
else:
BB = nd.take(B, index, axis=0)
return nd.batch_dot(A.expand_dims(1), BB).squeeze()
return nd.batch_dot(A.expand_dims(1), BB).squeeze(1)

def normalize(x, p=2, axis=1, eps=1e-12):
r"""Performs :math:`L_p` normalization of inputs over specified dimension.
Expand Down
2 changes: 1 addition & 1 deletion python/dgl/nn/pytorch/conv/relgraphconv.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ def basis_message_func(self, edges, etypes):
etypes = th.repeat_interleave(th.arange(len(etypes), device=device),
th.tensor(etypes, device=device))
weight = weight.index_select(0, etypes)
msg = th.bmm(h.unsqueeze(1), weight).squeeze()
msg = th.bmm(h.unsqueeze(1), weight).squeeze(1)

if 'norm' in edges.data:
msg = msg * edges.data['norm']
Expand Down
2 changes: 1 addition & 1 deletion python/dgl/nn/tensorflow/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def bmm_maybe_select(A, B, index):
return tf.gather(B, flatidx)
else:
BB = tf.gather(B, index)
return tf.squeeze(tf.matmul(tf.expand_dims(A, 1), BB))
return tf.squeeze(tf.matmul(tf.expand_dims(A, 1), BB), 1)


class Identity(layers.Layer):
Expand Down
Loading

0 comments on commit 9e63010

Please sign in to comment.