test: update test parameters
XuehaiPan committed Aug 29, 2022
1 parent 9976f96 commit 4c726f2
Showing 3 changed files with 10 additions and 10 deletions.
4 changes: 2 additions & 2 deletions tests/helpers.py
@@ -174,8 +174,8 @@ def assert_all_close(
     from torch.testing._comparison import get_tolerances
 
     rtol, atol = get_tolerances(actual, expected, rtol=rtol, atol=atol)
-    rtol *= 10 * NUM_UPDATES
-    atol *= 10 * NUM_UPDATES
+    rtol *= 4 * NUM_UPDATES
+    atol *= 4 * NUM_UPDATES
 
     torch.testing.assert_close(
         actual,
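
For context, assert_all_close in tests/helpers.py starts from the default tolerances that torch.testing.assert_close would pick for the given dtypes and then loosens them in proportion to the number of optimizer updates; this commit tightens the scaling factor from 10 to 4. A minimal sketch of that idea, with an assumed NUM_UPDATES value and a simplified signature (the real helper has more parameters):

import torch
from torch.testing._comparison import get_tolerances

NUM_UPDATES = 5  # assumed value for illustration; the real constant is defined in tests/helpers.py


def assert_all_close(actual, expected, rtol=None, atol=None):
    # Start from the dtype-specific defaults used by torch.testing.assert_close,
    # then loosen them in proportion to the number of update steps, since
    # numerical error accumulates across steps. The commit lowers the factor
    # from 10 to 4, making the comparison stricter.
    rtol, atol = get_tolerances(actual, expected, rtol=rtol, atol=atol)
    rtol *= 4 * NUM_UPDATES
    atol *= 4 * NUM_UPDATES
    torch.testing.assert_close(actual, expected, rtol=rtol, atol=atol)
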
8 changes: 4 additions & 4 deletions tests/test_alias.py
@@ -88,7 +88,7 @@ def test_sgd(
 
 
 @helpers.parametrize(
-    dtype=[torch.float64, torch.float32],
+    dtype=[torch.float64],
     lr=[1e-2, 1e-3, 1e-4],
     betas=[(0.9, 0.999), (0.95, 0.9995)],
     eps=[1e-8],
@@ -146,7 +146,7 @@ def test_adam(
 
 
 @helpers.parametrize(
-    dtype=[torch.float64, torch.float32],
+    dtype=[torch.float64],
     lr=[1e-2, 1e-3, 1e-4],
     betas=[(0.9, 0.999), (0.95, 0.9995)],
     eps=[1e-8],
@@ -206,7 +206,7 @@ def test_adam_accelerated_cpu(
 
 @pytest.mark.skipif(not torch.cuda.is_available(), reason='No CUDA device available.')
 @helpers.parametrize(
-    dtype=[torch.float64, torch.float32],
+    dtype=[torch.float64],
     lr=[1e-2, 1e-3, 1e-4],
     betas=[(0.9, 0.999), (0.95, 0.9995)],
     eps=[1e-8],
@@ -267,7 +267,7 @@ def test_adam_accelerated_cuda(
 
 
 @helpers.parametrize(
-    dtype=[torch.float64, torch.float32],
+    dtype=[torch.float64],
     lr=[1e-2, 1e-3, 1e-4],
     alpha=[0.9, 0.99],
     eps=[1e-8],
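
The dtype change above shrinks the test matrix: assuming helpers.parametrize expands its keyword lists into the full cross product of combinations (in the spirit of pytest.mark.parametrize), dropping torch.float32 halves the number of generated cases for each of these tests. A hypothetical sketch of such an expansion (the helper below is illustrative, not the repository's actual implementation):

import itertools

import pytest
import torch


def parametrize(**argvalues):
    # Hypothetical cross-product parametrize, modelled on how the decorator is used in the diff.
    keys = list(argvalues)
    combinations = list(itertools.product(*argvalues.values()))
    return pytest.mark.parametrize(', '.join(keys), combinations)


# With dtype=[torch.float64], the Adam test runs 1 * 3 * 2 * 1 = 6 parameter
# combinations instead of the 12 produced when torch.float32 was also listed.
@parametrize(
    dtype=[torch.float64],
    lr=[1e-2, 1e-3, 1e-4],
    betas=[(0.9, 0.999), (0.95, 0.9995)],
    eps=[1e-8],
)
def test_adam(dtype, lr, betas, eps):
    assert dtype is torch.float64  # placeholder body for illustration
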
8 changes: 4 additions & 4 deletions tests/test_optimizer.py
@@ -84,7 +84,7 @@ def test_SGD(
 
 
 @helpers.parametrize(
-    dtype=[torch.float64, torch.float32],
+    dtype=[torch.float64],
     lr=[1e-2, 1e-3, 1e-4],
     betas=[(0.9, 0.999), (0.95, 0.9995)],
     eps=[1e-8],
@@ -139,7 +139,7 @@ def test_Adam(
 
 
 @helpers.parametrize(
-    dtype=[torch.float64, torch.float32],
+    dtype=[torch.float64],
     lr=[1e-2, 1e-3, 1e-4],
     betas=[(0.9, 0.999), (0.95, 0.9995)],
     eps=[1e-8],
@@ -196,7 +196,7 @@ def test_Adam_accelerated_cpu(
 
 @pytest.mark.skipif(not torch.cuda.is_available(), reason='No CUDA device available.')
 @helpers.parametrize(
-    dtype=[torch.float64, torch.float32],
+    dtype=[torch.float64],
     lr=[1e-2, 1e-3, 1e-4],
     betas=[(0.9, 0.999), (0.95, 0.9995)],
     eps=[1e-8],
@@ -254,7 +254,7 @@ def test_Adam_accelerated_cuda(
 
 
 @helpers.parametrize(
-    dtype=[torch.float64, torch.float32],
+    dtype=[torch.float64],
     lr=[1e-2, 1e-3, 1e-4],
     alpha=[0.9, 0.99],
     eps=[1e-8],
