diff --git a/src/transformers/models/graphormer/collating_graphormer.py b/src/transformers/models/graphormer/collating_graphormer.py index e2cccc6668a417..58ce602ea28de1 100644 --- a/src/transformers/models/graphormer/collating_graphormer.py +++ b/src/transformers/models/graphormer/collating_graphormer.py @@ -129,6 +129,6 @@ def __call__(self, features: List[dict]) -> Dict[str, Any]: else: # binary classification batch["labels"] = torch.from_numpy(np.concatenate([i["labels"] for i in features])) else: # multi task classification, left to float to keep the NaNs - batch["labels"] = torch.from_numpy(np.stack([i["labels"] for i in features], dim=0)) + batch["labels"] = torch.from_numpy(np.stack([i["labels"] for i in features], axis=0)) return batch