
Removed custom events from Image Classification template #98

Merged · 2 commits · Apr 20, 2021
4 changes: 2 additions & 2 deletions .github/run_code_style.sh
@@ -7,8 +7,8 @@ if [ $1 == "lint" ]; then
     isort app templates/*/_sidebar.py tests --check --settings pyproject.toml
     black app templates/*/_sidebar.py tests --check --config pyproject.toml
 elif [ $1 == "fmt" ]; then
-    isort app templates/*/_sidebar.py tests --color --settings pyproject.toml
-    black app templates/*/_sidebar.py tests --config pyproject.toml
+    isort app templates/*/_sidebar.py templates/*/test_all.py tests --color --settings pyproject.toml
+    black app templates/*/_sidebar.py templates/*/test_all.py tests --config pyproject.toml
 elif [ $1 == "install" ]; then
     pip install flake8 "black==20.8b1" "isort==5.7.0"
 fi
4 changes: 4 additions & 0 deletions .github/run_test.sh
@@ -5,6 +5,7 @@ set -xeuo pipefail
 if [ $1 == "generate" ]; then
     python ./tests/generate.py
 elif [ $1 == "unittest" ]; then
+    python ./tests/generate.py
     for dir in $(find ./tests/dist -type d -mindepth 1 -maxdepth 1 -not -path "./tests/dist/launch" -not -path "./tests/dist/spawn")
     do
         cd $dir
@@ -18,6 +19,7 @@ elif [ $1 == "unittest" ]; then
         cd ../../../
     done
 elif [ $1 == "default" ]; then
+    python ./tests/generate.py
     for file in $(find ./tests/dist -iname "main.py" -not -path "./tests/dist/launch/*" -not -path "./tests/dist/spawn/*" -not -path "./tests/dist/single/*")
     do
         python $file \
@@ -28,6 +30,7 @@ elif [ $1 == "default" ]; then
             --eval_epoch_length 10
     done
 elif [ $1 == "launch" ]; then
+    python ./tests/generate.py
     for file in $(find ./tests/dist/launch -iname "main.py" -not -path "./tests/dist/launch/single/*")
     do
         python -m torch.distributed.launch \
@@ -41,6 +44,7 @@ elif [ $1 == "launch" ]; then
            --log_every_iters 2
     done
 elif [ $1 == "spawn" ]; then
+    python ./tests/generate.py
     for file in $(find ./tests/dist/spawn -iname "main.py" -not -path "./tests/dist/spawn/single/*")
     do
         python $file \
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
@@ -41,12 +41,12 @@ jobs:
       - run: pip install -r requirements.txt --progress-bar off
       - run: pip install -r requirements-dev.txt -f https://download.pytorch.org/whl/cpu/torch_stable.html --progress-bar off
       - run: python -m torch.utils.collect_env
-      - run: bash .github/run_test.sh generate
       - run: bash .github/run_test.sh unittest
       - run: bash .github/run_test.sh default
       - run: bash .github/run_test.sh spawn
       - run: bash .github/run_test.sh launch
     env:
+      RUN_SLOW_TESTS: 1
       OMP_NUM_THREADS: 1

   lint:
3 changes: 3 additions & 0 deletions templates/gan/test_all.py
@@ -1,8 +1,10 @@
+import os
 from argparse import Namespace
 from numbers import Number
 from typing import Iterable

 import ignite.distributed as idist
+import pytest
 import torch
 from datasets import get_datasets
 from ignite.engine import Engine
@@ -23,6 +25,7 @@ def set_up():
     return model, optimizer, device, loss_fn, batch


+@pytest.mark.skipif(os.getenv("RUN_SLOW_TESTS", 0) == 0, reason="Skip slow tests")
 def test_get_datasets(tmp_path):
     dataset, _ = get_datasets("cifar10", tmp_path)
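
A note on the guard above, with a minimal standalone sketch (the test name below is illustrative, not from this PR): os.getenv returns a string when the variable is set, so the int default 0 only equals the result when RUN_SLOW_TESTS is absent; exporting RUN_SLOW_TESTS=1, as ci.yml now does, makes the comparison false and the slow dataset tests run.

    import os

    import pytest

    # Sketch of the gating pattern added in this PR. os.getenv("RUN_SLOW_TESTS", 0)
    # yields the int 0 when the variable is unset (0 == 0 -> test is skipped)
    # and the string "1" under CI's RUN_SLOW_TESTS: 1 ("1" == 0 -> test runs).
    @pytest.mark.skipif(os.getenv("RUN_SLOW_TESTS", 0) == 0, reason="Skip slow tests")
    def test_download_cifar10(tmp_path):  # hypothetical slow test
        ...
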
86 changes: 0 additions & 86 deletions templates/image_classification/_test_internal.py
@@ -24,10 +24,8 @@
 from test_all import set_up
 from torch import nn, optim
 from trainers import (
-    TrainEvents,
     create_trainers,
     evaluate_function,
-    train_events_to_attr,
     train_function,
 )
 from utils import (
@@ -92,88 +90,6 @@ def test_get_logger(tmp_path):
     assert isinstance(logger_handler, types), "Should be Ignite provided loggers or None"


-def test_train_fn():
-    model, optimizer, device, loss_fn, batch = set_up()
-    engine = Engine(lambda e, b: 1)
-    engine.register_events(*TrainEvents, event_to_attr=train_events_to_attr)
-    backward = MagicMock()
-    optim = MagicMock()
-    engine.add_event_handler(TrainEvents.BACKWARD_COMPLETED, backward)
-    engine.add_event_handler(TrainEvents.OPTIM_STEP_COMPLETED, optim)
-    config = Namespace(use_amp=False)
-    output = train_function(config, engine, batch, model, loss_fn, optimizer, device)
-    assert isinstance(output, dict)
-    assert hasattr(engine.state, "backward_completed")
-    assert hasattr(engine.state, "optim_step_completed")
-    assert engine.state.backward_completed == 1
-    assert engine.state.optim_step_completed == 1
-    assert backward.call_count == 1
-    assert optim.call_count == 1
-    assert backward.called
-    assert optim.called
-
-
-def test_train_fn_event_filter():
-    model, optimizer, device, loss_fn, batch = set_up()
-    config = Namespace(use_amp=False)
-    engine = Engine(lambda e, b: train_function(config, e, b, model, loss_fn, optimizer, device))
-    engine.register_events(*TrainEvents, event_to_attr=train_events_to_attr)
-    backward = MagicMock()
-    optim = MagicMock()
-    engine.add_event_handler(TrainEvents.BACKWARD_COMPLETED(event_filter=lambda _, x: (x % 2 == 0) or x == 3), backward)
-    engine.add_event_handler(TrainEvents.OPTIM_STEP_COMPLETED(event_filter=lambda _, x: (x % 2 == 0) or x == 3), optim)
-    engine.run([batch] * 5)
-    assert hasattr(engine.state, "backward_completed")
-    assert hasattr(engine.state, "optim_step_completed")
-    assert engine.state.backward_completed == 5
-    assert engine.state.optim_step_completed == 5
-    assert backward.call_count == 3
-    assert optim.call_count == 3
-    assert backward.called
-    assert optim.called
-
-
-def test_train_fn_every():
-    model, optimizer, device, loss_fn, batch = set_up()
-
-    config = Namespace(use_amp=False)
-    engine = Engine(lambda e, b: train_function(config, e, b, model, loss_fn, optimizer, device))
-    engine.register_events(*TrainEvents, event_to_attr=train_events_to_attr)
-    backward = MagicMock()
-    optim = MagicMock()
-    engine.add_event_handler(TrainEvents.BACKWARD_COMPLETED(every=2), backward)
-    engine.add_event_handler(TrainEvents.OPTIM_STEP_COMPLETED(every=2), optim)
-    engine.run([batch] * 5)
-    assert hasattr(engine.state, "backward_completed")
-    assert hasattr(engine.state, "optim_step_completed")
-    assert engine.state.backward_completed == 5
-    assert engine.state.optim_step_completed == 5
-    assert backward.call_count == 2
-    assert optim.call_count == 2
-    assert backward.called
-    assert optim.called
-
-
-def test_train_fn_once():
-    model, optimizer, device, loss_fn, batch = set_up()
-    config = Namespace(use_amp=False)
-    engine = Engine(lambda e, b: train_function(config, e, b, model, loss_fn, optimizer, device))
-    engine.register_events(*TrainEvents, event_to_attr=train_events_to_attr)
-    backward = MagicMock()
-    optim = MagicMock()
-    engine.add_event_handler(TrainEvents.BACKWARD_COMPLETED(once=3), backward)
-    engine.add_event_handler(TrainEvents.OPTIM_STEP_COMPLETED(once=3), optim)
-    engine.run([batch] * 5)
-    assert hasattr(engine.state, "backward_completed")
-    assert hasattr(engine.state, "optim_step_completed")
-    assert engine.state.backward_completed == 5
-    assert engine.state.optim_step_completed == 5
-    assert backward.call_count == 1
-    assert optim.call_count == 1
-    assert backward.called
-    assert optim.called
-
-
 def test_evaluate_fn():
     model, optimizer, device, loss_fn, batch = set_up()
     engine = Engine(lambda e, b: 1)
@@ -193,8 +109,6 @@ def test_create_trainers():
     )
     assert isinstance(trainer, Engine)
     assert isinstance(evaluator, Engine)
-    assert hasattr(trainer.state, "backward_completed")
-    assert hasattr(trainer.state, "optim_step_completed")


 def test_get_default_parser():
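
For reference, the call counts asserted in the deleted tests follow from Ignite's standard event filters; a minimal standalone sketch using the built-in ITERATION_COMPLETED event (not the removed custom events) reproduces them: over 5 iterations, every=2 fires at iterations 2 and 4, once=3 fires exactly once, and the lambda filter fires at 2, 3 and 4.

    from unittest.mock import MagicMock

    from ignite.engine import Engine, Events

    engine = Engine(lambda e, b: None)
    every_handler, once_handler, filtered_handler = MagicMock(), MagicMock(), MagicMock()
    engine.add_event_handler(Events.ITERATION_COMPLETED(every=2), every_handler)
    engine.add_event_handler(Events.ITERATION_COMPLETED(once=3), once_handler)
    engine.add_event_handler(
        Events.ITERATION_COMPLETED(event_filter=lambda _, x: (x % 2 == 0) or x == 3),
        filtered_handler,
    )
    engine.run([0] * 5)  # one epoch of 5 iterations
    assert every_handler.call_count == 2     # iterations 2 and 4
    assert once_handler.call_count == 1      # iteration 3 only
    assert filtered_handler.call_count == 3  # iterations 2, 3 and 4
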
3 changes: 3 additions & 0 deletions templates/image_classification/test_all.py
@@ -1,8 +1,10 @@
+import os
 from argparse import Namespace
 from numbers import Number
 from typing import Iterable

 import ignite.distributed as idist
+import pytest
 import torch
 from datasets import get_datasets
 from ignite.contrib.handlers.param_scheduler import ParamScheduler
@@ -25,6 +27,7 @@ def set_up():
     return model, optimizer, device, loss_fn, batch


+@pytest.mark.skipif(os.getenv("RUN_SLOW_TESTS", 0) == 0, reason="Skip slow tests")
 def test_get_datasets(tmp_path):
     train_ds, eval_ds = get_datasets(tmp_path)
9 changes: 0 additions & 9 deletions templates/image_classification/trainers.py
@@ -9,8 +9,6 @@
 from torch.cuda.amp import autocast
 from torch.optim.optimizer import Optimizer

-{% include "_events.py" %}
-

 # Edit below functions the way how the model will be training
@@ -63,13 +61,7 @@ def train_function(
         loss = loss_fn(outputs, targets)

     loss.backward()
-    engine.state.backward_completed += 1
-    engine.fire_event(TrainEvents.BACKWARD_COMPLETED)
-
     optimizer.step()
-    engine.state.optim_step_completed += 1
-    engine.fire_event(TrainEvents.OPTIM_STEP_COMPLETED)
-
     optimizer.zero_grad()

     loss_value = loss.item()
@@ -164,5 +156,4 @@ def create_trainers(config, model, optimizer, loss_fn, device) -> Tuple[Engine, Engine]:
             device=device
         )
     )
-    trainer.register_events(*TrainEvents, event_to_attr=train_events_to_attr)
     return trainer, evaluator
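
For context, the removed machinery came from the template's _events.py (not shown in this diff); below is a rough sketch of the pattern, assuming Ignite's EventEnum API and inferring names from the deleted lines — a reconstruction for illustration, not the template's exact code.

    from ignite.engine import Engine, EventEnum

    # Reconstruction for illustration only; the actual _events.py is not
    # part of this diff.
    class TrainEvents(EventEnum):
        BACKWARD_COMPLETED = "backward_completed"
        OPTIM_STEP_COMPLETED = "optim_step_completed"

    train_events_to_attr = {
        TrainEvents.BACKWARD_COMPLETED: "backward_completed",
        TrainEvents.OPTIM_STEP_COMPLETED: "optim_step_completed",
    }

    trainer = Engine(lambda engine, batch: None)
    trainer.register_events(*TrainEvents, event_to_attr=train_events_to_attr)

    # The deleted train_function lines then bumped the state counter and
    # fired the event right after loss.backward() and optimizer.step():
    #   engine.state.backward_completed += 1
    #   engine.fire_event(TrainEvents.BACKWARD_COMPLETED)
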
4 changes: 4 additions & 0 deletions templates/text_classification/test_all.py
@@ -1,8 +1,10 @@
+import os
 from argparse import Namespace
 from numbers import Number
 from typing import Iterable

 import ignite.distributed as idist
+import pytest
 import torch
 from dataset import get_dataflow, get_dataset
 from ignite.contrib.handlers.param_scheduler import ParamScheduler
@@ -41,6 +43,7 @@ def test_initialize():
     assert isinstance(lr_scheduler, (_LRScheduler, ParamScheduler))


+@pytest.mark.skipif(os.getenv("RUN_SLOW_TESTS", 0) == 0, reason="Skip slow tests")
 def test_get_dataflow():
     config = Namespace(
         data_dir="/tmp/data",
@@ -55,6 +58,7 @@ def test_get_dataflow():
     assert isinstance(test_loader, DataLoader)


+@pytest.mark.skipif(os.getenv("RUN_SLOW_TESTS", 0) == 0, reason="Skip slow tests")
 def test_get_dataset():
     cache_dir = "/tmp"
     tokenizer_name = "bert-base-uncased"
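
As a usage note (the path below is illustrative), the gated tests can be exercised locally by setting the same variable the CI workflow exports before invoking pytest:

    import os

    import pytest

    # Hypothetical local runner mirroring the CI env block: enable the slow,
    # dataset-downloading tests, then run pytest programmatically.
    os.environ["RUN_SLOW_TESTS"] = "1"
    raise SystemExit(pytest.main(["templates/text_classification/test_all.py"]))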