Skip to content

Commit

Permalink
Partial fix of node label ordering
Browse files Browse the repository at this point in the history
  • Loading branch information
WenkelF committed Sep 5, 2024
1 parent 6603014 commit c23dc02
Show file tree
Hide file tree
Showing 3 changed files with 233 additions and 32 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ draft/
scripts-expts/
sweeps/
mup/
loc-*

# Data and predictions
graphium/data/ZINC_bench_gnn/
Expand Down
2 changes: 1 addition & 1 deletion graphium/graphium_cpp/features.h
Original file line number Diff line number Diff line change
Expand Up @@ -274,4 +274,4 @@ std::tuple<std::vector<at::Tensor>, int64_t, int64_t> featurize_smiles(
// Parse a SMILES string into an RDKit molecule.
//
// smiles_string : the (possibly atom-map-numbered, i.e. "ordered") SMILES.
// explicit_H    : whether to keep explicit hydrogens on the parsed molecule.
// ordered       : when true (the default, per this commit), preserve the
//                 atom order given by the SMILES atom-map numbers so that
//                 node-level labels stay aligned with node features.
std::unique_ptr<RDKit::RWMol> parse_mol(
    const std::string& smiles_string,
    bool explicit_H,
    bool ordered = true);
262 changes: 231 additions & 31 deletions tests/test_node_label_order.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,75 +28,275 @@

class Test_NodeLabelOrdering(ut.TestCase):
    """Regression tests for node-label ordering with atom-map-numbered SMILES.

    Each sub-test builds a tiny dataset in which the node label encodes the
    expected argmax index of the one-hot atomic-number feature (based on the
    assertions below: C -> 0, O -> 2 — TODO confirm against the featurizer),
    so that label/feature alignment can be checked after datamodule
    processing. The processed-data cache is wiped before and after every
    sub-test so results cannot leak between them.
    """

    def test_node_label_ordering(self):
        # Start from a clean cache so stale processed data cannot leak in.
        if exists(TEMP_CACHE_DATA_PATH):
            rm(TEMP_CACHE_DATA_PATH, recursive=True)

        ###################################################################################################################
        ### Test I: Test if atom labels are ordered correctly for a single dataset that contains only a single molecule ###
        ###################################################################################################################

        # Single molecule; node labels encode the per-atom expected feature argmax.
        df = pd.DataFrame(
            {
                "ordered_smiles": ["[C:0][C:1][O:2]"],
                "node_labels": [[0., 0., 2.]],
            }
        )

        task_kwargs = {"df": df, "split_val": 0.0, "split_test": 0.0}

        # Single node-level task with one label column.
        task_specific_args = {
            "task": {"task_level": "node", "label_cols": ["node_labels"], "smiles_col": "ordered_smiles", "seed": 42, **task_kwargs},
        }

        dm = MultitaskFromSmilesDataModule(task_specific_args, processed_graph_data_path=TEMP_CACHE_DATA_PATH, featurization={"atom_property_list_onehot": ["atomic-number"]})
        dm.prepare_data()
        dm.setup()

        dm.train_ds.return_smiles = True

        dl = dm.train_dataloader()

        batch = next(iter(dl))

        # Labels must match the atom type recovered from the one-hot features.
        atom_types = batch["labels"].node_task.squeeze()
        atom_types_from_features = batch["features"].feat.argmax(1)

        np.testing.assert_array_equal(atom_types, atom_types_from_features)

        # Wipe the cache before the next sub-test.
        if exists(TEMP_CACHE_DATA_PATH):
            rm(TEMP_CACHE_DATA_PATH, recursive=True)

        ###################################################################################
        ### Test II: Two ordered SMILES representing the same molecule in same dataset ###
        ###################################################################################

        # Same molecule written with two different atom orders; the labels are
        # permuted accordingly, so alignment must hold for both rows.
        df = pd.DataFrame(
            {
                "ordered_smiles": ["[C:0][C:1][O:2]", "[O:0][C:1][C:2]"],
                "node_labels": [[0., 0., 2.], [2., 0., 0.]],
            }
        )

        task_kwargs = {"df": df, "split_val": 0.0, "split_test": 0.0}

        task_specific_args = {
            "task": {"task_level": "node", "label_cols": ["node_labels"], "smiles_col": "ordered_smiles", "seed": 42, **task_kwargs},
        }

        dm = MultitaskFromSmilesDataModule(task_specific_args, processed_graph_data_path=TEMP_CACHE_DATA_PATH, featurization={"atom_property_list_onehot": ["atomic-number"]})
        dm.prepare_data()
        dm.setup()

        dm.train_ds.return_smiles = True

        dl = dm.train_dataloader()

        batch = next(iter(dl))

        atom_types = batch["labels"].node_task.squeeze()
        atom_types_from_features = batch["features"].feat.argmax(1)

        np.testing.assert_array_equal(atom_types_from_features, atom_types)

        # Wipe the cache before the next sub-test.
        if exists(TEMP_CACHE_DATA_PATH):
            rm(TEMP_CACHE_DATA_PATH, recursive=True)

        #############################################################################################
        ### Test III: Merging two node-level tasks each with different ordering of ordered SMILES ###
        ### TODO: Will currently fail                                                             ###
        #############################################################################################

        # Two tasks over the same molecule, each using a different atom order.
        df1 = pd.DataFrame(
            {
                "ordered_smiles": ["[C:0][C:1][O:2]"],
                "node_labels": [[0., 0., 2.]],
            }
        )

        df2 = pd.DataFrame(
            {
                "ordered_smiles": ["[O:0][C:1][C:2]"],
                "node_labels": [[2., 0., 0.]],
            }
        )

        task1_kwargs = {"df": df1, "split_val": 0.0, "split_test": 0.0}
        task2_kwargs = {"df": df2, "split_val": 0.0, "split_test": 0.0}

        task_specific_args = {
            "task1": {"task_level": "node", "label_cols": ["node_labels"], "smiles_col": "ordered_smiles", "seed": 42, **task1_kwargs},
            "task2": {"task_level": "node", "label_cols": ["node_labels"], "smiles_col": "ordered_smiles", "seed": 42, **task2_kwargs},
        }

        dm = MultitaskFromSmilesDataModule(task_specific_args, processed_graph_data_path=TEMP_CACHE_DATA_PATH, featurization={"atom_property_list_onehot": ["atomic-number"]})
        dm.prepare_data()
        dm.setup()

        dm.train_ds.return_smiles = True

        dl = dm.train_dataloader()

        batch = next(iter(dl))

        # Split the batched labels/features back per-molecule to compare them.
        unbatched_node_labels1 = unbatch(batch["labels"].node_task1, batch["labels"].batch)
        unbatched_node_labels2 = unbatch(batch["labels"].node_task2, batch["labels"].batch)
        unbatched_node_features = unbatch(batch["features"].feat, batch["features"].batch)

        atom_types1 = unbatched_node_labels1[0].squeeze()
        atom_types2 = unbatched_node_labels2[0].squeeze()
        atom_types_from_features = unbatched_node_features[0].argmax(1)

        # Both tasks' labels must align with the (single) feature ordering.
        np.testing.assert_array_equal(atom_types_from_features, atom_types1)
        np.testing.assert_array_equal(atom_types_from_features, atom_types2)

        # Wipe the cache before the next sub-test.
        if exists(TEMP_CACHE_DATA_PATH):
            rm(TEMP_CACHE_DATA_PATH, recursive=True)

        ###############################################################################
        ### Test IV: Merging node-level task on graph-level task with no node order ###
        ### NOTE: Works as rdkit does not merge ordered_smiles vs. unordered smiles ###
        ###############################################################################

        # Graph-level task on an unordered SMILES, node-level task on an
        # ordered SMILES of the same molecule.
        df1 = pd.DataFrame(
            {
                "ordered_smiles": ["CCO"],
                "graph_labels": [1.],
            }
        )

        df2 = pd.DataFrame(
            {
                "ordered_smiles": ["[O:0][C:1][C:2]"],
                "node_labels": [[2., 0., 0.]],
            }
        )

        task1_kwargs = {"df": df1, "split_val": 0.0, "split_test": 0.0}
        task2_kwargs = {"df": df2, "split_val": 0.0, "split_test": 0.0}

        task_specific_args = {
            "task1": {"task_level": "graph", "label_cols": ["graph_labels"], "smiles_col": "ordered_smiles", "seed": 42, **task1_kwargs},
            "task2": {"task_level": "node", "label_cols": ["node_labels"], "smiles_col": "ordered_smiles", "seed": 42, **task2_kwargs},
        }

        dm = MultitaskFromSmilesDataModule(task_specific_args, processed_graph_data_path=TEMP_CACHE_DATA_PATH, featurization={"atom_property_list_onehot": ["atomic-number"]})
        dm.prepare_data()
        dm.setup()

        dm.train_ds.return_smiles = True

        dl = dm.train_dataloader()

        batch = next(iter(dl))

        atom_types = batch["labels"].node_task2.squeeze()
        atom_types_from_features = batch["features"].feat.argmax(1)

        # Molecules without node labels yield NaNs; mask them on both sides
        # with a sentinel so assert_array_equal only compares labeled atoms.
        nan_indices = atom_types.isnan()
        atom_types_from_features[nan_indices] = 333
        atom_types[nan_indices] = 333

        np.testing.assert_array_equal(atom_types, atom_types_from_features)

        # Wipe the cache before the next sub-test.
        if exists(TEMP_CACHE_DATA_PATH):
            rm(TEMP_CACHE_DATA_PATH, recursive=True)

        #####################################################################################
        ### Test V: Merging node-level task on graph-level task with different node order ###
        ### TODO: Will currently fail                                                     ###
        #####################################################################################

        # Both tasks use ordered SMILES of the same molecule, but with
        # different atom orders.
        df1 = pd.DataFrame(
            {
                "ordered_smiles": ["[C:0][C:1][O:2]"],
                "graph_labels": [1.],
            }
        )

        df2 = pd.DataFrame(
            {
                "ordered_smiles": ["[O:0][C:1][C:2]"],
                "node_labels": [[2., 0., 0.]],
            }
        )

        task1_kwargs = {"df": df1, "split_val": 0.0, "split_test": 0.0}
        task2_kwargs = {"df": df2, "split_val": 0.0, "split_test": 0.0}

        task_specific_args = {
            "task1": {"task_level": "graph", "label_cols": ["graph_labels"], "smiles_col": "ordered_smiles", "seed": 42, **task1_kwargs},
            "task2": {"task_level": "node", "label_cols": ["node_labels"], "smiles_col": "ordered_smiles", "seed": 42, **task2_kwargs},
        }

        dm = MultitaskFromSmilesDataModule(task_specific_args, processed_graph_data_path=TEMP_CACHE_DATA_PATH, featurization={"atom_property_list_onehot": ["atomic-number"]})
        dm.prepare_data()
        dm.setup()

        dm.train_ds.return_smiles = True

        dl = dm.train_dataloader()

        batch = next(iter(dl))

        atom_types = batch["labels"].node_task2.squeeze()
        atom_types_from_features = batch["features"].feat.argmax(1)

        np.testing.assert_array_equal(atom_types, atom_types_from_features)

        # Wipe the cache before the next sub-test.
        if exists(TEMP_CACHE_DATA_PATH):
            rm(TEMP_CACHE_DATA_PATH, recursive=True)

        ############################
        ### Test VI: ...         ###
        ### TODO: To be finished ###
        ############################

        # Mix of unordered and ordered SMILES on a graph-level task; only
        # exercises the pipeline for now — no assertions yet.
        df = pd.DataFrame(
            {
                "smiles": ["CCO", "OCC", "COC", "[C:0][C:1][O:2]", "[O:0][C:1][C:2]", "[C:0][O:1][C:2]"],
                "graph_labels": [0., 0., 1., 0., 0., 1.],
            }
        )

        task_kwargs = {"df": df, "split_val": 0.0, "split_test": 0.0}

        task_specific_args = {
            "task": {"task_level": "graph", "label_cols": ["graph_labels"], "smiles_col": "smiles", "seed": 42, **task_kwargs},
        }

        dm = MultitaskFromSmilesDataModule(task_specific_args, processed_graph_data_path=TEMP_CACHE_DATA_PATH, featurization={"atom_property_list_onehot": ["atomic-number"]})
        dm.prepare_data()
        dm.setup()

        dm.train_ds.return_smiles = True

        dl = dm.train_dataloader()

        batch = next(iter(dl))

        # Leave the cache clean for whatever runs next.
        if exists(TEMP_CACHE_DATA_PATH):
            rm(TEMP_CACHE_DATA_PATH, recursive=True)


if __name__ == "__main__":
Expand Down

0 comments on commit c23dc02

Please sign in to comment.