Skip to content

Commit

Permalink
Merge pull request #369 from datamol-io/fix-label-dtype-error
Browse files Browse the repository at this point in the history
Fix label dtype error
  • Loading branch information
callumm-graphcore authored Jun 28, 2023
2 parents 2566355 + a8a036d commit 0d9e656
Show file tree
Hide file tree
Showing 4 changed files with 88 additions and 17 deletions.
16 changes: 14 additions & 2 deletions graphium/data/collate.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
def graphium_collate_fn(
elements: Union[List[Any], Dict[str, List[Any]]],
labels_size_dict: Optional[Dict[str, Any]] = None,
labels_dtype_dict: Optional[Dict[str, Any]] = None,
mask_nan: Union[str, float, Type[None]] = "raise",
do_not_collate_keys: List[str] = [],
batch_size_per_pack: Optional[int] = None,
Expand All @@ -42,6 +43,11 @@ def graphium_collate_fn(
and the size of the label tensor as value. The size of the tensor corresponds to how many
labels/values there are to predict for that task.
labels_dtype_dict:
(Note): This is an attribute of the `MultitaskDataset`.
A dictionary of the form Dict[tasks, dtypes] which has task names as keys
and the dtype of the label tensor as value. This is necessary to ensure the missing labels are added with NaNs of the right dtype
mask_nan:
Deal with the NaN/Inf when calling the function `make_pyg_graph`.
Some values become `Inf` when changing data type. This allows to deal
Expand Down Expand Up @@ -72,7 +78,7 @@ def graphium_collate_fn(
# Multitask setting: We have to pad the missing labels
if key == "labels":
labels = [d[key] for d in elements]
batch[key] = collate_labels(labels, labels_size_dict)
batch[key] = collate_labels(labels, labels_size_dict, labels_dtype_dict)

# If the features are a dictionary containing GraphDict elements,
# Convert to pyg graphs and use the pyg batching.
Expand Down Expand Up @@ -223,6 +229,7 @@ def get_expected_label_size(label_data: Data, task: str, label_size: List[int]):
def collate_labels(
labels: List[Data],
labels_size_dict: Optional[Dict[str, Any]] = None,
labels_dtype_dict: Optional[Dict[str, Any]] = None,
):
"""Collate labels for multitask learning.
Expand All @@ -231,6 +238,10 @@ def collate_labels(
labels_size_dict: Dict of the form Dict[tasks, sizes] which has task names as keys
and the size of the label tensor as value. The size of the tensor corresponds to how many
labels/values there are to predict for that task.
labels_dtype_dict:
(Note): This is an attribute of the `MultitaskDataset`.
A dictionary of the form Dict[tasks, dtypes] which has task names as keys
and the dtype of the label tensor as value. This is necessary to ensure the missing labels are added with NaNs of the right dtype
Returns:
A dictionary of the form Dict[tasks, labels] where tasks is the name of the task and labels
Expand All @@ -248,7 +259,8 @@ def collate_labels(
empty_task_labels = set(labels_size_dict.keys()) - set(this_label.keys)
for task in empty_task_labels:
labels_size_dict[task] = get_expected_label_size(this_label, task, labels_size_dict[task])
this_label[task] = torch.full([*labels_size_dict[task]], torch.nan)
dtype = labels_dtype_dict[task]
this_label[task] = torch.full([*labels_size_dict[task]], torch.nan, dtype=dtype)

for task in set(this_label.keys) - set(["x", "edge_index"]) - empty_task_labels:
labels_size_dict[task] = get_expected_label_size(this_label, task, labels_size_dict[task])
Expand Down
9 changes: 9 additions & 0 deletions graphium/data/datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -1141,6 +1141,7 @@ def setup(
# Can possibly get rid of setup because a single dataset will have molecules exclusively in train, val or test
# Produce the label sizes to update the collate function
labels_size = {}
labels_dtype = {}
if stage == "fit" or stage is None:
if self.load_from_file:
processed_train_data_path = self._path_to_load_from_file("train")
Expand All @@ -1163,6 +1164,8 @@ def setup(
self.train_ds.labels_size
) # Make sure that all task label sizes are contained in here. Maybe do the update outside these if statements.
labels_size.update(self.val_ds.labels_size)
labels_dtype.update(self.train_ds.labels_dtype)
labels_dtype.update(self.val_ds.labels_dtype)

if stage == "test" or stage is None:
if self.load_from_file:
Expand All @@ -1176,12 +1179,18 @@ def setup(
logger.info(self.test_ds)

labels_size.update(self.test_ds.labels_size)
labels_dtype.update(self.test_ds.labels_dtype)

default_labels_size_dict = self.collate_fn.keywords.get("labels_size_dict", None)

if default_labels_size_dict is None:
self.collate_fn.keywords["labels_size_dict"] = labels_size

default_labels_dtype_dict = self.collate_fn.keywords.get("labels_dtype_dict", None)

if default_labels_dtype_dict is None:
self.collate_fn.keywords["labels_dtype_dict"] = labels_dtype

def _make_multitask_dataset(
self,
stage: Literal["train", "val", "test"],
Expand Down
52 changes: 44 additions & 8 deletions graphium/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,7 @@ def __init__(
self.mol_ids = None
self.smiles = None
self.labels_size = self.set_label_size_dict(datasets)
self.labels_dtype = self.set_label_dtype_dict(datasets)
self.dataset_length = len(self.labels)
self._num_nodes_list = None
self._num_edges_list = None
Expand All @@ -219,6 +220,7 @@ def save_metadata(self, directory: str):
"mol_ids",
"smiles",
"labels_size",
"labels_dtype",
"dataset_length",
"_num_nodes_list",
"_num_edges_list",
Expand All @@ -237,13 +239,23 @@ def _load_metadata(self):
"mol_ids",
"smiles",
"labels_size",
"labels_dtype",
"dataset_length",
"_num_nodes_list",
"_num_edges_list",
]
path = os.path.join(self.data_path, "multitask_metadata.pkl")
attrs = torch.load(path)

if not set(attrs_to_load).issubset(set(attrs.keys())):
raise ValueError(
f"The metadata in the cache at {self.data_path} does not contain the right information. "
f"This may be because the cache was prepared using an earlier version of Graphium. "
f"You can try deleting the cache and running the data preparation again. "
f"\nMetadata keys found: {attrs.keys()}"
f"\nMetadata keys required: {attrs_to_load}"
)

for attr, value in attrs.items():
setattr(self, attr, value)

Expand Down Expand Up @@ -512,6 +524,21 @@ def _get_inv_of_mol_ids(self, all_mol_ids):
mol_ids, inv = np.unique(all_mol_ids, return_inverse=True)
return mol_ids, inv

def _find_valid_label(self, task, ds):
r"""
For a given dataset, find a genuine label for that dataset
"""
valid_label = None
for i in range(len(ds)):
if ds[i] is not None:
valid_label = ds[i]["labels"]
break

if valid_label is None:
raise ValueError(f"Dataset for task {task} has no valid labels.")

return valid_label

def set_label_size_dict(self, datasets: Dict[str, SingleTaskDataset]):
r"""
This gives the number of labels to predict for a given task.
Expand All @@ -521,14 +548,7 @@ def set_label_size_dict(self, datasets: Dict[str, SingleTaskDataset]):
if len(ds) == 0:
continue

valid_label = None
for i in range(len(ds)):
if ds[i] is not None:
valid_label = ds[i]["labels"]
break

if valid_label is None:
raise ValueError(f"Dataset for task {task} has no valid labels.")
valid_label = self._find_valid_label(task, ds)

# Assume for a fixed task, the label dimension is the same across data points
torch_label = torch.as_tensor(valid_label)
Expand All @@ -537,6 +557,21 @@ def set_label_size_dict(self, datasets: Dict[str, SingleTaskDataset]):
task_labels_size[task] = torch_label.size()
return task_labels_size

def set_label_dtype_dict(self, datasets: Dict[str, SingleTaskDataset]):
    r"""
    Build a mapping from task name to the torch dtype of that task's labels.

    Empty datasets are skipped; for each non-empty dataset the dtype is taken
    from the first valid (non-``None``) label, assuming the dtype is the same
    across all data points of a task.

    Args:
        datasets: Mapping from task name to its ``SingleTaskDataset``.

    Returns:
        Dict of the form ``{task: dtype}`` for every non-empty dataset.
    """
    return {
        task: torch.as_tensor(self._find_valid_label(task, ds)).dtype
        for task, ds in datasets.items()
        if len(ds) > 0
    }

def __repr__(self) -> str:
"""
summarizes the dataset in a string
Expand Down Expand Up @@ -619,6 +654,7 @@ def __init__(
)

self.labels_size = self.set_label_size_dict(datasets)
self.labels_dtype = self.set_label_dtype_dict(datasets)
self.features = self.features

def _get_inv_of_mol_ids(self, all_mol_ids):
Expand Down
28 changes: 21 additions & 7 deletions tests/test_collate.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,16 @@ def test_collate_labels(self):
"edge_label3": [5, 2],
"node_label4": [5, 1],
}
labels_dtype_dict = {
"graph_label1": torch.float32,
"graph_label2": torch.float16,
"node_label2": torch.float32,
"edge_label3": torch.float32,
"node_label4": torch.float32,
}
fake_label = {
"graph_label1": torch.FloatTensor([1]),
"graph_label2": torch.FloatTensor([1, 2, 3]),
"graph_label2": torch.HalfTensor([1, 2, 3]),
"node_label2": torch.FloatTensor([1, 2, 3, 4, 5]),
"edge_label3": torch.FloatTensor([[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]]),
"node_label4": torch.FloatTensor([[1], [2], [3], [4], [5]]),
Expand All @@ -36,14 +43,19 @@ def test_collate_labels(self):
pyg_labels[key] = val + 17 * 2
fake_labels.append(pyg_labels)

# Collate labels and check for the right shapes
collated_labels = collate_labels(deepcopy(fake_labels), deepcopy(labels_size_dict))
# Collate labels and check for the right shapes and dtypes
collated_labels = collate_labels(
deepcopy(fake_labels), deepcopy(labels_size_dict), deepcopy(labels_dtype_dict)
)
self.assertEqual(collated_labels["graph_label1"].shape, torch.Size([num_labels, 1])) # , 1
self.assertEqual(collated_labels["graph_label2"].shape, torch.Size([num_labels, 3])) # , 1
self.assertEqual(collated_labels["node_label2"].shape, torch.Size([num_labels * 5, 1])) # , 5
self.assertEqual(collated_labels["edge_label3"].shape, torch.Size([num_labels * 5, 2])) # , 5, 2
self.assertEqual(collated_labels["node_label4"].shape, torch.Size([num_labels * 5, 1])) # , 5, 1

self.assertEqual(collated_labels["graph_label1"].dtype, torch.float32)
self.assertEqual(collated_labels["graph_label2"].dtype, torch.float16)

# Check that the values are correct
graph_label1_true = deepcopy(torch.stack([this_label["graph_label1"] for this_label in fake_labels]))
graph_label2_true = deepcopy(torch.stack([this_label["graph_label2"] for this_label in fake_labels]))
Expand Down Expand Up @@ -89,7 +101,9 @@ def test_collate_labels(self):
"edge_label3": [5, 2],
"node_label4": [5, 1],
}
collated_labels = collate_labels(deepcopy(fake_labels), deepcopy(labels_size_dict))
collated_labels = collate_labels(
deepcopy(fake_labels), deepcopy(labels_size_dict), deepcopy(labels_dtype_dict)
)
self.assertEqual(collated_labels["graph_label1"].shape, torch.Size([num_labels, 1])) # , 1
self.assertEqual(collated_labels["graph_label2"].shape, torch.Size([num_labels, 3])) # , 1
self.assertEqual(collated_labels["node_label2"].shape, torch.Size([num_labels * 5, 1])) # , 5
Expand All @@ -111,9 +125,9 @@ def test_collate_labels(self):
)
# Now test the `graphium_collate_fn` function when only labels are given
fake_labels2 = [{"labels": this_label} for this_label in fake_labels]
collated_labels = graphium_collate_fn(deepcopy(fake_labels2), labels_size_dict=labels_size_dict)[
"labels"
]
collated_labels = graphium_collate_fn(
deepcopy(fake_labels2), labels_size_dict=labels_size_dict, labels_dtype_dict=labels_dtype_dict
)["labels"]
self.assertEqual(collated_labels["graph_label1"].shape, torch.Size([num_labels, 1]))
self.assertEqual(collated_labels["graph_label2"].shape, torch.Size([num_labels, 3]))
self.assertEqual(collated_labels["node_label2"].shape, torch.Size([num_labels * 5, 1])) # , 5
Expand Down

0 comments on commit 0d9e656

Please sign in to comment.