Skip to content

Commit 1a3ca94

Browse files
authored
Removed custom events from Image Classification template (#98)
* Removed custom events from Image Classification template * Added skip for slow tests on dataset (downloading) Fixed failing tests for image classification
1 parent a9d774b commit 1a3ca94

File tree

8 files changed

+17
-98
lines changed

8 files changed

+17
-98
lines changed

.github/run_code_style.sh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,8 @@ if [ $1 == "lint" ]; then
77
isort app templates/*/_sidebar.py tests --check --settings pyproject.toml
88
black app templates/*/_sidebar.py tests --check --config pyproject.toml
99
elif [ $1 == "fmt" ]; then
10-
isort app templates/*/_sidebar.py tests --color --settings pyproject.toml
11-
black app templates/*/_sidebar.py tests --config pyproject.toml
10+
isort app templates/*/_sidebar.py templates/*/test_all.py tests --color --settings pyproject.toml
11+
black app templates/*/_sidebar.py templates/*/test_all.py tests --config pyproject.toml
1212
elif [ $1 == "install" ]; then
1313
pip install flake8 "black==20.8b1" "isort==5.7.0"
1414
fi

.github/run_test.sh

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ set -xeuo pipefail
55
if [ $1 == "generate" ]; then
66
python ./tests/generate.py
77
elif [ $1 == "unittest" ]; then
8+
python ./tests/generate.py
89
for dir in $(find ./tests/dist -type d -mindepth 1 -maxdepth 1 -not -path "./tests/dist/launch" -not -path "./tests/dist/spawn")
910
do
1011
cd $dir
@@ -18,6 +19,7 @@ elif [ $1 == "unittest" ]; then
1819
cd ../../../
1920
done
2021
elif [ $1 == "default" ]; then
22+
python ./tests/generate.py
2123
for file in $(find ./tests/dist -iname "main.py" -not -path "./tests/dist/launch/*" -not -path "./tests/dist/spawn/*" -not -path "./tests/dist/single/*")
2224
do
2325
python $file \
@@ -28,6 +30,7 @@ elif [ $1 == "default" ]; then
2830
--eval_epoch_length 10
2931
done
3032
elif [ $1 == "launch" ]; then
33+
python ./tests/generate.py
3134
for file in $(find ./tests/dist/launch -iname "main.py" -not -path "./tests/dist/launch/single/*")
3235
do
3336
python -m torch.distributed.launch \
@@ -41,6 +44,7 @@ elif [ $1 == "launch" ]; then
4144
--log_every_iters 2
4245
done
4346
elif [ $1 == "spawn" ]; then
47+
python ./tests/generate.py
4448
for file in $(find ./tests/dist/spawn -iname "main.py" -not -path "./tests/dist/spawn/single/*")
4549
do
4650
python $file \

.github/workflows/ci.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,12 +41,12 @@ jobs:
4141
- run: pip install -r requirements.txt --progress-bar off
4242
- run: pip install -r requirements-dev.txt -f https://download.pytorch.org/whl/cpu/torch_stable.html --progress-bar off
4343
- run: python -m torch.utils.collect_env
44-
- run: bash .github/run_test.sh generate
4544
- run: bash .github/run_test.sh unittest
4645
- run: bash .github/run_test.sh default
4746
- run: bash .github/run_test.sh spawn
4847
- run: bash .github/run_test.sh launch
4948
env:
49+
RUN_SLOW_TESTS: 1
5050
OMP_NUM_THREADS: 1
5151

5252
lint:

templates/gan/test_all.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
1+
import os
12
from argparse import Namespace
23
from numbers import Number
34
from typing import Iterable
45

56
import ignite.distributed as idist
7+
import pytest
68
import torch
79
from datasets import get_datasets
810
from ignite.engine import Engine
@@ -23,6 +25,7 @@ def set_up():
2325
return model, optimizer, device, loss_fn, batch
2426

2527

28+
@pytest.mark.skipif(os.getenv("RUN_SLOW_TESTS", 0) == 0, reason="Skip slow tests")
2629
def test_get_datasets(tmp_path):
2730
dataset, _ = get_datasets("cifar10", tmp_path)
2831

templates/image_classification/_test_internal.py

Lines changed: 0 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,8 @@
2424
from test_all import set_up
2525
from torch import nn, optim
2626
from trainers import (
27-
TrainEvents,
2827
create_trainers,
2928
evaluate_function,
30-
train_events_to_attr,
3129
train_function,
3230
)
3331
from utils import (
@@ -92,88 +90,6 @@ def test_get_logger(tmp_path):
9290
assert isinstance(logger_handler, types), "Should be Ignite provided loggers or None"
9391

9492

95-
def test_train_fn():
96-
model, optimizer, device, loss_fn, batch = set_up()
97-
engine = Engine(lambda e, b: 1)
98-
engine.register_events(*TrainEvents, event_to_attr=train_events_to_attr)
99-
backward = MagicMock()
100-
optim = MagicMock()
101-
engine.add_event_handler(TrainEvents.BACKWARD_COMPLETED, backward)
102-
engine.add_event_handler(TrainEvents.OPTIM_STEP_COMPLETED, optim)
103-
config = Namespace(use_amp=False)
104-
output = train_function(config, engine, batch, model, loss_fn, optimizer, device)
105-
assert isinstance(output, dict)
106-
assert hasattr(engine.state, "backward_completed")
107-
assert hasattr(engine.state, "optim_step_completed")
108-
assert engine.state.backward_completed == 1
109-
assert engine.state.optim_step_completed == 1
110-
assert backward.call_count == 1
111-
assert optim.call_count == 1
112-
assert backward.called
113-
assert optim.called
114-
115-
116-
def test_train_fn_event_filter():
117-
model, optimizer, device, loss_fn, batch = set_up()
118-
config = Namespace(use_amp=False)
119-
engine = Engine(lambda e, b: train_function(config, e, b, model, loss_fn, optimizer, device))
120-
engine.register_events(*TrainEvents, event_to_attr=train_events_to_attr)
121-
backward = MagicMock()
122-
optim = MagicMock()
123-
engine.add_event_handler(TrainEvents.BACKWARD_COMPLETED(event_filter=lambda _, x: (x % 2 == 0) or x == 3), backward)
124-
engine.add_event_handler(TrainEvents.OPTIM_STEP_COMPLETED(event_filter=lambda _, x: (x % 2 == 0) or x == 3), optim)
125-
engine.run([batch] * 5)
126-
assert hasattr(engine.state, "backward_completed")
127-
assert hasattr(engine.state, "optim_step_completed")
128-
assert engine.state.backward_completed == 5
129-
assert engine.state.optim_step_completed == 5
130-
assert backward.call_count == 3
131-
assert optim.call_count == 3
132-
assert backward.called
133-
assert optim.called
134-
135-
136-
def test_train_fn_every():
137-
model, optimizer, device, loss_fn, batch = set_up()
138-
139-
config = Namespace(use_amp=False)
140-
engine = Engine(lambda e, b: train_function(config, e, b, model, loss_fn, optimizer, device))
141-
engine.register_events(*TrainEvents, event_to_attr=train_events_to_attr)
142-
backward = MagicMock()
143-
optim = MagicMock()
144-
engine.add_event_handler(TrainEvents.BACKWARD_COMPLETED(every=2), backward)
145-
engine.add_event_handler(TrainEvents.OPTIM_STEP_COMPLETED(every=2), optim)
146-
engine.run([batch] * 5)
147-
assert hasattr(engine.state, "backward_completed")
148-
assert hasattr(engine.state, "optim_step_completed")
149-
assert engine.state.backward_completed == 5
150-
assert engine.state.optim_step_completed == 5
151-
assert backward.call_count == 2
152-
assert optim.call_count == 2
153-
assert backward.called
154-
assert optim.called
155-
156-
157-
def test_train_fn_once():
158-
model, optimizer, device, loss_fn, batch = set_up()
159-
config = Namespace(use_amp=False)
160-
engine = Engine(lambda e, b: train_function(config, e, b, model, loss_fn, optimizer, device))
161-
engine.register_events(*TrainEvents, event_to_attr=train_events_to_attr)
162-
backward = MagicMock()
163-
optim = MagicMock()
164-
engine.add_event_handler(TrainEvents.BACKWARD_COMPLETED(once=3), backward)
165-
engine.add_event_handler(TrainEvents.OPTIM_STEP_COMPLETED(once=3), optim)
166-
engine.run([batch] * 5)
167-
assert hasattr(engine.state, "backward_completed")
168-
assert hasattr(engine.state, "optim_step_completed")
169-
assert engine.state.backward_completed == 5
170-
assert engine.state.optim_step_completed == 5
171-
assert backward.call_count == 1
172-
assert optim.call_count == 1
173-
assert backward.called
174-
assert optim.called
175-
176-
17793
def test_evaluate_fn():
17894
model, optimizer, device, loss_fn, batch = set_up()
17995
engine = Engine(lambda e, b: 1)
@@ -193,8 +109,6 @@ def test_create_trainers():
193109
)
194110
assert isinstance(trainer, Engine)
195111
assert isinstance(evaluator, Engine)
196-
assert hasattr(trainer.state, "backward_completed")
197-
assert hasattr(trainer.state, "optim_step_completed")
198112

199113

200114
def test_get_default_parser():

templates/image_classification/test_all.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
1+
import os
12
from argparse import Namespace
23
from numbers import Number
34
from typing import Iterable
45

56
import ignite.distributed as idist
7+
import pytest
68
import torch
79
from datasets import get_datasets
810
from ignite.contrib.handlers.param_scheduler import ParamScheduler
@@ -25,6 +27,7 @@ def set_up():
2527
return model, optimizer, device, loss_fn, batch
2628

2729

30+
@pytest.mark.skipif(os.getenv("RUN_SLOW_TESTS", 0) == 0, reason="Skip slow tests")
2831
def test_get_datasets(tmp_path):
2932
train_ds, eval_ds = get_datasets(tmp_path)
3033

templates/image_classification/trainers.py

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,6 @@
88
from torch.cuda.amp import autocast
99
from torch.optim.optimizer import Optimizer
1010

11-
{% include "_events.py" %}
12-
1311

1412
# Edit below functions the way how the model will be training
1513

@@ -62,13 +60,7 @@ def train_function(
6260
loss = loss_fn(outputs, targets)
6361

6462
loss.backward()
65-
engine.state.backward_completed += 1
66-
engine.fire_event(TrainEvents.BACKWARD_COMPLETED)
67-
6863
optimizer.step()
69-
engine.state.optim_step_completed += 1
70-
engine.fire_event(TrainEvents.OPTIM_STEP_COMPLETED)
71-
7264
optimizer.zero_grad()
7365

7466
loss_value = loss.item()
@@ -163,5 +155,4 @@ def create_trainers(config, model, optimizer, loss_fn, device) -> Tuple[Engine,
163155
device=device
164156
)
165157
)
166-
trainer.register_events(*TrainEvents, event_to_attr=train_events_to_attr)
167158
return trainer, evaluator

templates/text_classification/test_all.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
1+
import os
12
from argparse import Namespace
23
from numbers import Number
34
from typing import Iterable
45

56
import ignite.distributed as idist
7+
import pytest
68
import torch
79
from dataset import get_dataflow, get_dataset
810
from ignite.contrib.handlers.param_scheduler import ParamScheduler
@@ -41,6 +43,7 @@ def test_initialize():
4143
assert isinstance(lr_scheduler, (_LRScheduler, ParamScheduler))
4244

4345

46+
@pytest.mark.skipif(os.getenv("RUN_SLOW_TESTS", 0) == 0, reason="Skip slow tests")
4447
def test_get_dataflow():
4548
config = Namespace(
4649
data_dir="/tmp/data",
@@ -55,6 +58,7 @@ def test_get_dataflow():
5558
assert isinstance(test_loader, DataLoader)
5659

5760

61+
@pytest.mark.skipif(os.getenv("RUN_SLOW_TESTS", 0) == 0, reason="Skip slow tests")
5862
def test_get_dataset():
5963
cache_dir = "/tmp"
6064
tokenizer_name = "bert-base-uncased"

0 commit comments

Comments
 (0)