Skip to content

Commit 32f690a

Browse files
committed
Drop torchvision dependency for unit tests
1 parent 8887fa7 commit 32f690a

3 files changed

Lines changed: 12 additions & 6 deletions

File tree

.github/workflows/pip.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ jobs:
2525
python-version: ${{ matrix.python-version }}
2626

2727
- name: Install PyTorch
28-
run: pip install torch torchvision pillow==6.1 --extra-index-url https://download.pytorch.org/whl/cpu
28+
run: pip install torch --extra-index-url https://download.pytorch.org/whl/cpu
2929

3030
- name: Build and install
3131
run: pip install --verbose .[test]

setup.cfg

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,5 +20,4 @@ where = src
2020

2121
[options.extras_require]
2222
test =
23-
pytest
24-
torchvision
23+
pytest

tests/powersgd_test.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,18 @@
11
import torch
2-
import torchvision
32

43
from powersgd import PowerSGD, Config
54

5+
def build_model():
6+
return torch.nn.Sequential(
7+
torch.nn.Conv2d(3, 100, 3),
8+
torch.nn.ReLU(),
9+
torch.nn.Conv2d(100, 50, 5),
10+
torch.nn.Linear(50, 1)
11+
)
12+
613

714
def test_no_compression_in_the_beginning():
8-
model = torchvision.models.resnet50()
15+
model = build_model()
916
params = list(model.parameters())
1017
config = Config(
1118
rank=1,
@@ -29,7 +36,7 @@ def test_no_compression_in_the_beginning():
2936

3037
def test_error_feedback_mechanism():
3138
torch.set_default_dtype(torch.float64)
32-
model = torchvision.models.resnet50()
39+
model = build_model()
3340
params = list(model.parameters())
3441
config = Config(
3542
rank=2,

0 commit comments

Comments
 (0)