Skip to content

Commit

Permalink
Cap quadratic complexity of LinkPredPersonalization (#10058)
Browse files Browse the repository at this point in the history
  • Loading branch information
rusty1s authored Feb 21, 2025
1 parent 328bb47 commit d802320
Showing 1 changed file with 26 additions and 5 deletions.
31 changes: 26 additions & 5 deletions torch_geometric/metrics/link_pred.py
Original file line number Diff line number Diff line change
Expand Up @@ -715,34 +715,54 @@ class LinkPredPersonalization(_LinkPredMetric):
Args:
k (int): The number of top-:math:`k` predictions to evaluate against.
max_src_nodes (int, optional): The maximum source nodes to consider to
compute pair-wise dissimilarity. If specified,
Personalization @ :math:`k` is approximated to avoid computation
blowup due to quadratic complexity. (default: :obj:`2**12`)
batch_size (int, optional): The batch size to determine how many pairs
of user recommendations should be processed at once.
(default: :obj:`2**16`)
"""
higher_is_better: bool = True

def __init__(self, k: int, batch_size: int = 2**16) -> None:
def __init__(
self,
k: int,
max_src_nodes: Optional[int] = 2**12,
batch_size: int = 2**16,
) -> None:
super().__init__(k)
self.max_src_nodes = max_src_nodes
self.batch_size = batch_size

if WITH_TORCHMETRICS:
self.add_state('preds', default=[], dist_reduce_fx='cat')
self.add_state('dev_tensor', torch.empty(0), dist_reduce_fx='sum')
self.add_state('total', torch.tensor(0), dist_reduce_fx='sum')
else:
self.preds: List[Tensor] = []
self.register_buffer('dev_tensor', torch.empty(0))
self.register_buffer('total', torch.tensor(0))

def update(
self,
pred_index_mat: Tensor,
edge_label_index: Union[Tensor, Tuple[Tensor, Tensor]],
edge_label_weight: Optional[Tensor] = None,
) -> None:

# NOTE Move to CPU to avoid memory blowup.
self.preds.append(pred_index_mat[:, :self.k].cpu())
pred_index_mat = pred_index_mat[:, :self.k].cpu()

if self.max_src_nodes is None:
self.preds.append(pred_index_mat)
self.total += pred_index_mat.size(0)
elif self.total < self.max_src_nodes:
remaining = int(self.max_src_nodes - self.total)
pred_index_mat = pred_index_mat[:remaining]
self.preds.append(pred_index_mat)
self.total += pred_index_mat.size(0)

def compute(self) -> Tensor:
device = self.dev_tensor.device
device = self.total.device
score = torch.tensor(0.0, device=device)
total = torch.tensor(0, device=device)

Expand Down Expand Up @@ -786,6 +806,7 @@ def compute(self) -> Tensor:

def _reset(self) -> None:
self.preds = []
self.total.zero_()


class LinkPredAveragePopularity(_LinkPredMetric):
Expand Down

0 comments on commit d802320

Please sign in to comment.