Skip to content

Commit

Permalink
[Type Hints] datasets.ModelNet (#5701)
Browse files Browse the repository at this point in the history
Added TypeHints

Co-authored-by: vsitaraman <vilayannur.sitaraman@hitachivantara.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Matthias Fey <matthias.fey@tu-dortmund.de>
  • Loading branch information
4 people authored Oct 15, 2022
1 parent 753d1e6 commit 2a64140
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 8 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Support `in_channels` with `tuple` in `GENConv` for bipartite message passing ([#5627](https://github.com/pyg-team/pytorch_geometric/pull/5627), [#5641](https://github.com/pyg-team/pytorch_geometric/pull/5641))
- Handle cases of not having enough possible negative edges in `RandomLinkSplit` ([#5642](https://github.com/pyg-team/pytorch_geometric/pull/5642))
- Fix `RGCN+pyg-lib` for `LongTensor` input ([#5610](https://github.com/pyg-team/pytorch_geometric/pull/5610))
- Improved type hint support ([#5603](https://github.com/pyg-team/pytorch_geometric/pull/5603), [#5659](https://github.com/pyg-team/pytorch_geometric/pull/5659), [#5664](https://github.com/pyg-team/pytorch_geometric/pull/5664), [#5665](https://github.com/pyg-team/pytorch_geometric/pull/5665), [#5666](https://github.com/pyg-team/pytorch_geometric/pull/5666), [#5667](https://github.com/pyg-team/pytorch_geometric/pull/5667), [#5668](https://github.com/pyg-team/pytorch_geometric/pull/5668), [#5669](https://github.com/pyg-team/pytorch_geometric/pull/5669), [#5673](https://github.com/pyg-team/pytorch_geometric/pull/5673), [#5675](https://github.com/pyg-team/pytorch_geometric/pull/5675), [#5673](https://github.com/pyg-team/pytorch_geometric/pull/5676), [#5678](https://github.com/pyg-team/pytorch_geometric/pull/5678), [#5682](https://github.com/pyg-team/pytorch_geometric/pull/5682), [#5683](https://github.com/pyg-team/pytorch_geometric/pull/5683), [#5684](https://github.com/pyg-team/pytorch_geometric/pull/5684), [#5685](https://github.com/pyg-team/pytorch_geometric/pull/5685), [#5687](https://github.com/pyg-team/pytorch_geometric/pull/5687), [#5688](https://github.com/pyg-team/pytorch_geometric/pull/5688), [#5695](https://github.com/pyg-team/pytorch_geometric/pull/5695), [#5699](https://github.com/pyg-team/pytorch_geometric/pull/5699), [#5706](https://github.com/pyg-team/pytorch_geometric/pull/5706), [#5702](https://github.com/pyg-team/pytorch_geometric/pull/5702))
- Improved type hint support ([#5603](https://github.com/pyg-team/pytorch_geometric/pull/5603), [#5659](https://github.com/pyg-team/pytorch_geometric/pull/5659), [#5664](https://github.com/pyg-team/pytorch_geometric/pull/5664), [#5665](https://github.com/pyg-team/pytorch_geometric/pull/5665), [#5666](https://github.com/pyg-team/pytorch_geometric/pull/5666), [#5667](https://github.com/pyg-team/pytorch_geometric/pull/5667), [#5668](https://github.com/pyg-team/pytorch_geometric/pull/5668), [#5669](https://github.com/pyg-team/pytorch_geometric/pull/5669), [#5673](https://github.com/pyg-team/pytorch_geometric/pull/5673), [#5675](https://github.com/pyg-team/pytorch_geometric/pull/5675), [#5673](https://github.com/pyg-team/pytorch_geometric/pull/5676), [#5678](https://github.com/pyg-team/pytorch_geometric/pull/5678), [#5682](https://github.com/pyg-team/pytorch_geometric/pull/5682), [#5683](https://github.com/pyg-team/pytorch_geometric/pull/5683), [#5684](https://github.com/pyg-team/pytorch_geometric/pull/5684), [#5685](https://github.com/pyg-team/pytorch_geometric/pull/5685), [#5687](https://github.com/pyg-team/pytorch_geometric/pull/5687), [#5688](https://github.com/pyg-team/pytorch_geometric/pull/5688), [#5695](https://github.com/pyg-team/pytorch_geometric/pull/5695), [#5699](https://github.com/pyg-team/pytorch_geometric/pull/5699), [#5701] (https://github.com/pyg-team/pytorch_geometric/pull/5701), [#5706](https://github.com/pyg-team/pytorch_geometric/pull/5706), [#5702](https://github.com/pyg-team/pytorch_geometric/pull/5702))
- Avoid modifying `mode_kwargs` in `MultiAggregation` ([#5601](https://github.com/pyg-team/pytorch_geometric/pull/5601))
- Changed `BatchNorm` to allow for batches of size one during training ([#5530](https://github.com/pyg-team/pytorch_geometric/pull/5530), [#5614](https://github.com/pyg-team/pytorch_geometric/pull/5614))
- Integrated better temporal sampling support by requiring that local neighborhoods are sorted according to time ([#5516](https://github.com/pyg-team/pytorch_geometric/issues/5516), [#5602](https://github.com/pyg-team/pytorch_geometric/issues/5602))
Expand Down
28 changes: 21 additions & 7 deletions torch_geometric/datasets/modelnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,17 @@
import os
import os.path as osp
import shutil
from typing import Callable, Dict, List, Optional, Tuple

import torch

from torch_geometric.data import InMemoryDataset, download_url, extract_zip
from torch import Tensor

from torch_geometric.data import (
Data,
InMemoryDataset,
download_url,
extract_zip,
)
from torch_geometric.io import read_off


Expand Down Expand Up @@ -51,23 +58,30 @@ class ModelNet(InMemoryDataset):
'40': 'http://modelnet.cs.princeton.edu/ModelNet40.zip'
}

def __init__(self, root, name='10', train=True, transform=None,
pre_transform=None, pre_filter=None):
def __init__(
self,
root: str,
name: str = '10',
train: bool = True,
transform: Optional[Callable] = None,
pre_transform: Optional[Callable] = None,
pre_filter: Optional[Callable] = None,
):
assert name in ['10', '40']
self.name = name
super().__init__(root, transform, pre_transform, pre_filter)
path = self.processed_paths[0] if train else self.processed_paths[1]
self.data, self.slices = torch.load(path)

@property
def raw_file_names(self):
def raw_file_names(self) -> List[str]:
return [
'bathtub', 'bed', 'chair', 'desk', 'dresser', 'monitor',
'night_stand', 'sofa', 'table', 'toilet'
]

@property
def processed_file_names(self):
def processed_file_names(self) -> List[str]:
return ['training.pt', 'test.pt']

def download(self):
Expand All @@ -87,7 +101,7 @@ def process(self):
torch.save(self.process_set('train'), self.processed_paths[0])
torch.save(self.process_set('test'), self.processed_paths[1])

def process_set(self, dataset):
def process_set(self, dataset: str) -> Tuple[Data, Dict[str, Tensor]]:
categories = glob.glob(osp.join(self.raw_dir, '*', ''))
categories = sorted([x.split(os.sep)[-2] for x in categories])

Expand Down

0 comments on commit 2a64140

Please sign in to comment.