Skip to content

Commit

Permalink
HeteroData: num_features impl (#4504)
Browse files Browse the repository at this point in the history
  • Loading branch information
rusty1s authored Apr 20, 2022
1 parent ecf6374 commit e3891b1
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 6 deletions.
20 changes: 15 additions & 5 deletions test/data/test_hetero_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
edge_index_paper_author = torch.stack([idx_paper[:30], idx_author[:30]], dim=0)
edge_index_author_paper = torch.stack([idx_paper[:30], idx_author[:30]], dim=0)

edge_attr_paper_paper = torch.randn(edge_index_paper_paper.size(1), 8)


def get_edge_index(num_src_nodes, num_dst_nodes, num_edges):
row = torch.randint(num_src_nodes, (num_edges, ), dtype=torch.long)
Expand Down Expand Up @@ -77,12 +79,20 @@ def test_hetero_data_functions():
data['paper', 'paper'].edge_index = edge_index_paper_paper
data['paper', 'author'].edge_index = edge_index_paper_author
data['author', 'paper'].edge_index = edge_index_author_paper
assert len(data) == 2
assert sorted(data.keys) == ['edge_index', 'x']
assert 'x' in data and 'edge_index' in data
data['paper', 'paper'].edge_attr = edge_attr_paper_paper
assert len(data) == 3
assert sorted(data.keys) == ['edge_attr', 'edge_index', 'x']
assert 'x' in data and 'edge_index' in data and 'edge_attr' in data
assert data.num_nodes == 15
assert data.num_edges == 110

assert data.num_node_features == {'paper': 16, 'author': 32}
assert data.num_edge_features == {
('paper', 'to', 'paper'): 8,
('paper', 'to', 'author'): 0,
('author', 'to', 'paper'): 0,
}

node_types, edge_types = data.metadata()
assert node_types == ['paper', 'author']
assert edge_types == [
Expand All @@ -99,8 +109,8 @@ def test_hetero_data_functions():

data.y = 0
assert data['y'] == 0 and data.y == 0
assert len(data) == 3
assert sorted(data.keys) == ['edge_index', 'x', 'y']
assert len(data) == 4
assert sorted(data.keys) == ['edge_attr', 'edge_index', 'x', 'y']

del data['paper', 'author']
node_types, edge_types = data.metadata()
Expand Down
2 changes: 1 addition & 1 deletion torch_geometric/data/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -630,7 +630,7 @@ def num_node_features(self) -> int:
def num_features(self) -> int:
r"""Returns the number of features per node in the graph.
Alias for :py:attr:`~num_node_features`."""
return self._store.num_features
return self.num_node_features

@property
def num_edge_features(self) -> int:
Expand Down
22 changes: 22 additions & 0 deletions torch_geometric/data/hetero_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,28 @@ def num_nodes(self) -> Optional[int]:
r"""Returns the number of nodes in the graph."""
return super().num_nodes

@property
def num_node_features(self) -> Dict[NodeType, int]:
r"""Returns the number of features per node type in the graph."""
return {
key: store.num_node_features
for key, store in self._node_store_dict.items()
}

@property
def num_features(self) -> Dict[NodeType, int]:
r"""Returns the number of features per node type in the graph.
Alias for :py:attr:`~num_node_features`."""
return self.num_node_features

@property
def num_edge_features(self) -> Dict[EdgeType, int]:
r"""Returns the number of features per edge type in the graph."""
return {
key: store.num_edge_features
for key, store in self._edge_store_dict.items()
}

def debug(self):
pass # TODO

Expand Down

0 comments on commit e3891b1

Please sign in to comment.