Skip to content

Commit

Permalink
Fix Python 3.10 tests (pyg-team#4982)
Browse files Browse the repository at this point in the history
* fix test

* typo

* changelog

* linting
  • Loading branch information
rusty1s authored Jul 14, 2022
1 parent 42876ab commit a7224df
Show file tree
Hide file tree
Showing 5 changed files with 22 additions and 1 deletion.
2 changes: 2 additions & 0 deletions .github/workflows/full_testing.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ jobs:
include:
- torch-version: 1.12.0
torchvision-version: 0.13.0
- os: windows-latest
python-version: '3.7'

steps:
- uses: actions/checkout@v3
Expand Down
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added support for graph-level outputs in `to_hetero` ([#4582](https://github.com/pyg-team/pytorch_geometric/pull/4582))
- Added `CHANGELOG.md` ([#4581](https://github.com/pyg-team/pytorch_geometric/pull/4581))
### Changed
- Fixed packaging tests for Python 3.10 ([4982](https://github.com/pyg-team/pytorch_geometric/pull/4982))
- Changed `act_dict` (part of `graphgym`) to create individual instances instead of reusing the same ones everywhere ([4978](https://github.com/pyg-team/pytorch_geometric/pull/4978))
- Fixed issue where one-hot tensors were passed to `F.one_hot` ([4970](https://github.com/pyg-team/pytorch_geometric/pull/4970))
- Fixed `bool` arugments in `argparse` in `benchmark/` ([#4967](https://github.com/pyg-team/pytorch_geometric/pull/4967))
Expand Down
2 changes: 2 additions & 0 deletions test/nn/models/test_basic_gnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from torch_geometric.loader import NeighborLoader
from torch_geometric.nn import LayerNorm
from torch_geometric.nn.models import GAT, GCN, GIN, PNA, GraphSAGE
from torch_geometric.testing import withPython

out_dims = [None, 8]
dropouts = [0.0, 0.5]
Expand Down Expand Up @@ -125,6 +126,7 @@ def test_basic_gnn_inference(get_dataset, jk):
assert 'n_id' not in data


@withPython('3.7', '3.8', '3.9') # Packaging does not support Python 3.10 yet.
def test_packaging():
os.makedirs(torch.hub._get_torch_home(), exist_ok=True)

Expand Down
3 changes: 2 additions & 1 deletion torch_geometric/testing/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from .decorators import (is_full_test, onlyFullTest, withPackage,
from .decorators import (is_full_test, onlyFullTest, withPython, withPackage,
withRegisteredOp, withCUDA)

__all__ = [
'is_full_test',
'onlyFullTest',
'withPython',
'withPackage',
'withRegisteredOp',
'withCUDA',
Expand Down
15 changes: 15 additions & 0 deletions torch_geometric/testing/decorators.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
import sys
from importlib.util import find_spec
from typing import Callable

Expand All @@ -20,6 +21,20 @@ def onlyFullTest(func: Callable) -> Callable:
)(func)


def withPython(*args) -> Callable:
r"""A decorator to skip tests for any Python version not listed."""
def decorator(func: Callable) -> Callable:
import pytest

python_version = f'{sys.version_info.major}.{sys.version_info.minor}'
return pytest.mark.skipif(
python_version not in args,
reason=f"Python {python_version} not supported",
)(func)

return decorator


def withPackage(*args) -> Callable:
r"""A decorator to skip tests if certain packages are not installed."""
na_packages = set(arg for arg in args if find_spec(arg) is None)
Expand Down

0 comments on commit a7224df

Please sign in to comment.