Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

pyg-lib ToHeteroModule #5992

Merged
merged 144 commits into from
Jan 21, 2023
Merged
Changes from 1 commit
Commits
Show all changes
144 commits
Select commit Hold shift + click to select a range
7093fde
initial rough draft commit
puririshi98 Nov 16, 2022
d1d05fa
initial rough draft commit
puririshi98 Nov 16, 2022
bd51d3e
initial rough draft commit
puririshi98 Nov 16, 2022
d10db3e
Merge branch 'master' into to_hetero_pyglib
puririshi98 Nov 16, 2022
f3d25ab
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 16, 2022
74577b1
initial rough draft commit
puririshi98 Nov 17, 2022
910ebcf
Merge branch 'to_hetero_pyglib' of https://github.com/pyg-team/pytorc…
puririshi98 Nov 17, 2022
174894c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 17, 2022
0622d10
initial rough draft commit
puririshi98 Nov 17, 2022
c5431b2
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 17, 2022
211c00a
initial rough draft commit
puririshi98 Nov 17, 2022
4b53277
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 17, 2022
fc64305
wip
puririshi98 Nov 17, 2022
b83a0b9
wip
puririshi98 Nov 17, 2022
bea2d64
wip
puririshi98 Nov 17, 2022
80486c4
wip
puririshi98 Nov 17, 2022
2bfd5b7
wip
puririshi98 Nov 17, 2022
b25e8db
wip
puririshi98 Nov 17, 2022
944ec45
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 17, 2022
f1ee4eb
wip
puririshi98 Nov 17, 2022
078dfe8
Merge branch 'to_hetero_pyglib' of https://github.com/pyg-team/pytorc…
puririshi98 Nov 17, 2022
27d1d80
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 17, 2022
18787c2
wip
puririshi98 Nov 17, 2022
040feb6
Merge branch 'to_hetero_pyglib' of https://github.com/pyg-team/pytorc…
puririshi98 Nov 17, 2022
d5f5e6b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 17, 2022
f6376aa
Merge branch 'master' into to_hetero_pyglib
puririshi98 Nov 17, 2022
6f0a73d
wip
puririshi98 Nov 17, 2022
1b13091
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 17, 2022
8b12ae0
wip
puririshi98 Nov 17, 2022
0bbbd4f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 17, 2022
901c55d
wip
puririshi98 Nov 17, 2022
555ae84
Merge branch 'to_hetero_pyglib' of https://github.com/pyg-team/pytorc…
puririshi98 Nov 17, 2022
c40d07f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 17, 2022
8e8b829
wip
puririshi98 Nov 17, 2022
999b060
wip
puririshi98 Nov 17, 2022
c312d0e
wip
puririshi98 Nov 17, 2022
b7cf5b4
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 17, 2022
3308cee
wip
puririshi98 Nov 17, 2022
72ad4c8
wip
puririshi98 Nov 17, 2022
498dbd9
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 17, 2022
2181f1c
wip
puririshi98 Nov 17, 2022
09e4979
wip
puririshi98 Nov 17, 2022
3200be6
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 17, 2022
8dba2a4
wip
puririshi98 Nov 17, 2022
4e79b2a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 17, 2022
42575cd
wip
puririshi98 Nov 17, 2022
338b94a
Merge branch 'to_hetero_pyglib' of https://github.com/pyg-team/pytorc…
puririshi98 Nov 17, 2022
301c11c
wip
puririshi98 Nov 17, 2022
e2eed22
wip
puririshi98 Nov 17, 2022
42fe4a0
wip
puririshi98 Nov 17, 2022
7dbf09f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 17, 2022
96b0fd1
wip
Nov 17, 2022
322b56c
wip
Nov 17, 2022
e10f920
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 17, 2022
eb8e7b6
wip
Nov 17, 2022
9fec117
wip
Nov 18, 2022
22d51d8
wip
puririshi98 Nov 18, 2022
27852ff
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 18, 2022
e17cc77
wip
puririshi98 Nov 18, 2022
c929d60
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 18, 2022
b2a45cd
wip
puririshi98 Nov 18, 2022
01cf12a
Merge branch 'to_hetero_pyglib' of https://github.com/pyg-team/pytorc…
puririshi98 Nov 18, 2022
3f7fc3e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 18, 2022
3e7e259
wip
Nov 18, 2022
decc2ae
wip
Nov 18, 2022
dc1748e
wip
Nov 18, 2022
dbc3ce2
Merge branch 'master' into to_hetero_pyglib
puririshi98 Nov 18, 2022
da55860
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 18, 2022
5f78a06
wip
puririshi98 Nov 18, 2022
a51c0ab
wip
puririshi98 Nov 18, 2022
dfa66b2
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 18, 2022
4a14c2f
wip
puririshi98 Nov 18, 2022
bb2c1ed
wip
puririshi98 Nov 18, 2022
995b7ca
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 18, 2022
cc2fdc6
wip
puririshi98 Nov 18, 2022
dfad552
wip
puririshi98 Nov 18, 2022
37c7b8d
wip
puririshi98 Nov 18, 2022
91f4bb1
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 18, 2022
cc86f38
wip
puririshi98 Nov 18, 2022
d4aa0a4
wip
puririshi98 Nov 18, 2022
43c3330
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 18, 2022
648c653
wip
puririshi98 Nov 18, 2022
01dcd19
Merge branch 'to_hetero_pyglib' of https://github.com/pyg-team/pytorc…
puririshi98 Nov 18, 2022
0ef84ce
wip
puririshi98 Nov 18, 2022
6c1b09d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 18, 2022
8e6fb10
wip
puririshi98 Nov 18, 2022
35dd5d1
Merge branch 'to_hetero_pyglib' of https://github.com/pyg-team/pytorc…
puririshi98 Nov 18, 2022
5c2282b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 18, 2022
7190a38
wip
puririshi98 Nov 18, 2022
d71eb50
wip
puririshi98 Nov 18, 2022
962abc4
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 18, 2022
984af57
wip
puririshi98 Nov 18, 2022
407b2bb
wip
puririshi98 Nov 18, 2022
981f63f
wip
puririshi98 Nov 18, 2022
9eec274
measure dict based heterolin
puririshi98 Nov 18, 2022
bba82b8
wip
puririshi98 Nov 18, 2022
73772e5
wip
puririshi98 Nov 18, 2022
bf5a022
wip
puririshi98 Nov 18, 2022
446287a
wip
puririshi98 Nov 18, 2022
03be758
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 18, 2022
e24dbd4
wip
puririshi98 Nov 18, 2022
71c031f
wip
puririshi98 Nov 18, 2022
607e089
wip
puririshi98 Nov 18, 2022
07c9f38
wip
puririshi98 Nov 18, 2022
0ebd23f
wip
puririshi98 Nov 18, 2022
afa437f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 18, 2022
5c6511a
wip
puririshi98 Nov 18, 2022
8140ac5
Merge branch 'to_hetero_pyglib' of https://github.com/pyg-team/pytorc…
puririshi98 Nov 18, 2022
b48e92a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 18, 2022
2fe62f1
wip
puririshi98 Nov 18, 2022
5db076f
wip
puririshi98 Nov 18, 2022
25fd9e6
wip
puririshi98 Nov 18, 2022
aee8fb9
wip
puririshi98 Nov 18, 2022
44e6f8f
wip
puririshi98 Nov 18, 2022
1989b19
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 18, 2022
d1fd49e
Update test_to_hetero_module.py
puririshi98 Nov 28, 2022
8ff4d36
comments
puririshi98 Nov 28, 2022
c3a2fbe
Merge branch 'master' into to_hetero_pyglib
puririshi98 Nov 28, 2022
278315b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 28, 2022
960ecf1
Merge branch 'master' into to_hetero_pyglib
puririshi98 Nov 29, 2022
25d6927
Merge branch 'master' into to_hetero_pyglib
puririshi98 Nov 30, 2022
1ae145a
Merge branch 'master' into to_hetero_pyglib
puririshi98 Dec 1, 2022
d6f01ca
Merge branch 'master' of https://github.com/pyg-team/pytorch_geometri…
puririshi98 Dec 1, 2022
0f442c6
Merge branch 'to_hetero_pyglib' of https://github.com/pyg-team/pytorc…
puririshi98 Dec 1, 2022
f497fb5
Merge branch 'master' into to_hetero_pyglib
puririshi98 Dec 2, 2022
d97150f
Merge branch 'master' into to_hetero_pyglib
puririshi98 Dec 5, 2022
63e6bfb
Merge branch 'master' into to_hetero_pyglib
puririshi98 Dec 6, 2022
e4bbf0b
Merge branch 'master' into to_hetero_pyglib
puririshi98 Dec 7, 2022
5ce8297
Merge branch 'master' into to_hetero_pyglib
puririshi98 Dec 8, 2022
7be4e33
Merge branch 'master' into to_hetero_pyglib
puririshi98 Dec 9, 2022
3f1bd1f
Merge branch 'master' into to_hetero_pyglib
puririshi98 Dec 12, 2022
ffabaab
Merge branch 'master' into to_hetero_pyglib
puririshi98 Dec 14, 2022
92382ec
Merge branch 'master' into to_hetero_pyglib
puririshi98 Jan 5, 2023
073c642
Merge branch 'master' into to_hetero_pyglib
puririshi98 Jan 6, 2023
68d79df
Merge branch 'master' into to_hetero_pyglib
puririshi98 Jan 9, 2023
2d9dfb2
Merge branch 'master' into to_hetero_pyglib
puririshi98 Jan 11, 2023
c4befdf
Merge branch 'master' into to_hetero_pyglib
puririshi98 Jan 17, 2023
8148ac4
Merge branch 'master' into to_hetero_pyglib
puririshi98 Jan 18, 2023
0cfd11e
Merge branch 'master' into to_hetero_pyglib
puririshi98 Jan 19, 2023
d207cd8
Merge branch 'master' into to_hetero_pyglib
puririshi98 Jan 20, 2023
37e1372
update
rusty1s Jan 21, 2023
a18c401
update
rusty1s Jan 21, 2023
86d7146
update
rusty1s Jan 21, 2023
8f35fa4
Merge branch 'master' into to_hetero_pyglib
rusty1s Jan 21, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Nov 17, 2022
commit 174894ccaba73cc8fcad6831b5b4392cbcb4dc57
35 changes: 18 additions & 17 deletions torch_geometric/nn/to_hetero_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,8 @@ def forward(self, x, edge_index):
if _WITH_PYG_LIB:
return ToHeteroModule(module, metadata, aggr, input_map, debug)
else:
transformer = ToHeteroTransformer(module, metadata, aggr, input_map, debug)
transformer = ToHeteroTransformer(module, metadata, aggr, input_map,
debug)
return transformer.transform()


Expand Down Expand Up @@ -174,7 +175,8 @@ def __init__(
# parse out linear layers
for i, submodule in module.modules():
assert submodule_is_msg_passing_or_lin, "Current PyG"
if isinstance(submodule, torch.nn.Linear) or isinstance(submodule, torch_geometric.nn.dense.Linear):
if isinstance(submodule, torch.nn.Linear) or isinstance(
submodule, torch_geometric.nn.dense.Linear):
lin_module_idxs.append(i)

modules.append(submodule)
Expand All @@ -188,10 +190,8 @@ def __init__(
else:
in_ft = layer.in_channels
out_ft = layer.out_channels
heterolin = torch_geometric.nn.dense.HeteroLinear(in_ft,
out_ft,
len(self.node_types)
)
heterolin = torch_geometric.nn.dense.HeteroLinear(
in_ft, out_ft, len(self.node_types))
heterolin.reset_parameters()
modules_nested_list.append(heterolin)
else:
Expand All @@ -208,7 +208,8 @@ def __init__(

self.modules_nested_list = modules_nested_list

def fused_forward(self, x: Tensor, edge_index: Tensor, node_type: Tensor, edge_type: Tensor):
def fused_forward(self, x: Tensor, edge_index: Tensor, node_type: Tensor,
edge_type: Tensor):
r"""
Args:
x: The input node features. :obj:`[num_nodes, in_channels]`
Expand All @@ -229,13 +230,14 @@ def fused_forward(self, x: Tensor, edge_index: Tensor, node_type: Tensor, edge_t
e_idx_type_j = edge_index[:, edge_type == j]
o_j = layer(x, e_idx_type_j)
if j == 0:
out = torch.zeros(x.shape[0], o_j.shape[-1], device=x.device)
out = torch.zeros(x.shape[0], o_j.shape[-1],
device=x.device)
out += o_j
x = out
return x


def dict_forward(self, x_dict: Dict[NodeType, Tensor], edge_index_dict: Dict[EdgeType, Tensor]):
def dict_forward(self, x_dict: Dict[NodeType, Tensor],
edge_index_dict: Dict[EdgeType, Tensor]):
r"""
Args:
x_dict (Dict[str, Tensor]): A dictionary holding node feature
Expand All @@ -248,7 +250,10 @@ def dict_forward(self, x_dict: Dict[NodeType, Tensor], edge_index_dict: Dict[Edg
for layer_idx, typed_layers in enumerate(self.modules_nested_list):
if layer_idx in self.lin_module_idxs:
x = torch.cat([x_j for x_j in x_dict.values()])
node_type = torch.cat([j * torch.ones(x_j.shape[0]) for j, x_j in enumerate(x_dict.values())])
node_type = torch.cat([
j * torch.ones(x_j.shape[0])
for j, x_j in enumerate(x_dict.values())
])
# HeteroLinear layer
o = typed_layers(x, node_type)
o_dict = {}
Expand All @@ -269,11 +274,9 @@ def dict_forward(self, x_dict: Dict[NodeType, Tensor], edge_index_dict: Dict[Edg
x_dict = o_dict
return x


def foward(self, x: Union[Dict[NodeType, Tensor], Tensor],
edge_index:Union[Dict[EdgeType, Tensor], Tensor],
node_type:OptTensor = None,
edge_type:OptTensor = None):
edge_index: Union[Dict[EdgeType, Tensor], Tensor],
node_type: OptTensor = None, edge_type: OptTensor = None):
r"""
Args:
x (Dict[str, Tensor] or Tensor): A dictionary holding node feature
Expand Down Expand Up @@ -306,8 +309,6 @@ def foward(self, x: Union[Dict[NodeType, Tensor], Tensor],
return self.fused_forward(x, edge_index)




class ToHeteroTransformer(Transformer):

aggrs = {
Expand Down