Skip to content

Commit

Permalink
Merge branch 'graphium_3.0' into torchmetrics
Browse files Browse the repository at this point in the history
  • Loading branch information
AnujaSomthankar authored Sep 10, 2024
2 parents 2fb7f4b + c23dc02 commit ce4f94d
Show file tree
Hide file tree
Showing 7 changed files with 324 additions and 5 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ draft/
scripts-expts/
sweeps/
mup/
loc-*

# Data and predictions
graphium/data/ZINC_bench_gnn/
Expand Down
4 changes: 2 additions & 2 deletions env.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ dependencies:
- platformdirs

# scientific
- numpy < 2.0 # Issue with wandb
- numpy == 1.26.4
- scipy >=1.4
- pandas >=1.0
- scikit-learn
Expand Down Expand Up @@ -41,7 +41,7 @@ dependencies:
- pytorch_scatter >=2.0

# chemistry
- rdkit
- rdkit == 2024.03.4
- datamol >=0.10
- boost # needed by rdkit

Expand Down
6 changes: 6 additions & 0 deletions graphium/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ def __init__(
num_edges_tensor=None,
about: str = "",
data_path: Optional[Union[str, os.PathLike]] = None,
return_smiles: bool = False,
):
r"""
This class holds the information for the multitask dataset.
Expand Down Expand Up @@ -75,6 +76,8 @@ def __init__(
self.num_edges_tensor = num_edges_tensor
self.dataset_length = num_nodes_tensor.size(dim=0)

self.return_smiles = return_smiles

logger.info(f"Dataloading from DISK")

def __len__(self):
Expand Down Expand Up @@ -199,6 +202,9 @@ def __getitem__(self, idx):
"features": self.featurize_smiles(smiles_str),
}

if self.return_smiles:
datum["smiles"] = smiles_str

# One of the featurization error handling options returns a string on error,
# instead of throwing an exception, so assume that the intention is to just skip,
# instead of crashing.
Expand Down
2 changes: 1 addition & 1 deletion graphium/graphium_cpp/features.h
Original file line number Diff line number Diff line change
Expand Up @@ -274,4 +274,4 @@ std::tuple<std::vector<at::Tensor>, int64_t, int64_t> featurize_smiles(
std::unique_ptr<RDKit::RWMol> parse_mol(
const std::string& smiles_string,
bool explicit_H,
bool ordered = false);
bool ordered = true);
9 changes: 7 additions & 2 deletions graphium/trainer/predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -563,7 +563,6 @@ def training_step(self, batch: Dict[str, Tensor]) -> Dict[str, Any]:

return step_dict # Returning the metrics_logs with the loss


def validation_step(self, batch: Dict[str, Tensor]) -> Dict[str, Any]:
return self._general_step(batch=batch, step_name="val")

Expand All @@ -575,6 +574,12 @@ def _general_epoch_start(self, step_name: Literal["train", "val", "test"]) -> No
self.epoch_start_time[step_name] = time.time()
self.mean_time_tracker.reset()
self.mean_tput_tracker.reset()

def predict_step(self, batch: Dict[str, Tensor]) -> Dict[str, Any]:
preds = self.forward(batch) # The dictionary of predictions
targets_dict = batch.get("labels")

return preds, targets_dict


def _general_epoch_end(self, step_name: Literal["train", "val", "test"]) -> Dict[str, Tensor]:
Expand Down Expand Up @@ -628,12 +633,12 @@ def on_validation_epoch_end(self) -> None:
self._general_epoch_end(step_name="val")
return super().on_validation_epoch_end()


def on_test_epoch_start(self) -> None:
self._general_epoch_start(step_name="test")
return super().on_test_epoch_start()

def on_test_epoch_end(self) -> None:

self._general_epoch_end(step_name="test")
return super().on_test_epoch_end()

Expand Down
Binary file added tests/data/dummy_node_label_order_data.parquet
Binary file not shown.
Loading

0 comments on commit ce4f94d

Please sign in to comment.