Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TypeHints] OneHotDegree #5667

Merged
merged 9 commits into from
Oct 13, 2022
Next Next commit
type hints for one_hot_degree
  • Loading branch information
sebastian-montero committed Oct 12, 2022
commit 3227fd865c8305690f41b3159a3ee9cb2ce79886
9 changes: 7 additions & 2 deletions torch_geometric/transforms/one_hot_degree.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
from typing import Union

import torch
import torch.nn.functional as F

from torch_geometric.data import Data, HeteroData
from torch_geometric.data.datapipes import functional_transform
from torch_geometric.transforms import BaseTransform
from torch_geometric.utils import degree
Expand All @@ -19,12 +22,14 @@ class OneHotDegree(BaseTransform):
cat (bool, optional): Concat node degrees to node features instead
of replacing them. (default: :obj:`True`)
"""
def __init__(self, max_degree, in_degree=False, cat=True):
def __init__(self, max_degree: int, in_degree: bool = False,
cat: bool = True):
self.max_degree = max_degree
self.in_degree = in_degree
self.cat = cat

def __call__(self, data):
def __call__(self, data: Union[Data,
HeteroData]) -> Union[Data, HeteroData]:
idx, x = data.edge_index[1 if self.in_degree else 0], data.x
deg = degree(idx, data.num_nodes, dtype=torch.long)
deg = F.one_hot(deg, num_classes=self.max_degree + 1).to(torch.float)
Expand Down