Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FeatureStore abstraction definition #4534

Merged
merged 23 commits into from
Apr 29, 2022
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Add override example
  • Loading branch information
mananshah99 committed Apr 28, 2022
commit 65e5e0b23fe00058dfa78a5e7645592cd3d82feb
47 changes: 46 additions & 1 deletion test/data/test_feature_store.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from dataclasses import dataclass
from typing import Optional

import torch
Expand All @@ -6,6 +7,7 @@
AttrView,
FeatureStore,
TensorAttr,
_field_status,
)
from torch_geometric.typing import FeatureTensorType

Expand All @@ -27,7 +29,7 @@ def _put_tensor(self, tensor: FeatureTensorType, attr: TensorAttr) -> bool:

# Not set or None indices define the obvious index
if not attr.is_set('index') or index is None:
index = torch.range(0, tensor.shape[0] - 1)
index = torch.arange(0, tensor.shape[0])

# Store the index as a column
self.store[MyFeatureStore.key(attr)] = torch.cat(
Expand Down Expand Up @@ -57,6 +59,31 @@ def __len__(self):
raise NotImplementedError


@dataclass
class MyTensorAttrNoGroupName(TensorAttr):
def __init__(self, attr_name=_field_status.UNSET,
index=_field_status.UNSET):
# Treat group_name as optional, and move it to the end
super().__init__(None, attr_name, index)


@dataclass
class MyFeatureStoreNoGroupName(MyFeatureStore):
# pylint: disable=super-init-not-called
def __init__(self):
FeatureStore.__init__(self, backend='test',
attr_cls=MyTensorAttrNoGroupName)
self.store = {}

@classmethod
def key(cls, attr: TensorAttr):
r"""Define the key as (group_name, attr_name)."""
return attr.attr_name or ''

def __len__(self):
raise NotImplementedError


def test_feature_store():
r"""Tests basic API and indexing functionality of a feature store."""
store = MyFeatureStore()
Expand Down Expand Up @@ -112,3 +139,21 @@ def test_feature_store():
assert store[group_name, attr_name, index] is None
del store[group_name]
assert store[group_name]() is None


def test_feature_store_override():
store = MyFeatureStoreNoGroupName()
tensor = torch.Tensor([[0, 0, 0], [1, 1, 1], [2, 2, 2]])
index = torch.Tensor([0, 1, 2])

attr_name = 'feat'

# Only use attr_name and index, in that order
store[attr_name, index] = tensor

# A few assertions to ensure group_name is not needed
assert isinstance(store[attr_name], AttrView)
assert torch.equal(store[attr_name, index], tensor)
assert torch.equal(store[attr_name][index], tensor)
assert torch.equal(store[attr_name][:], tensor)
assert torch.equal(store[attr_name, :], tensor)