Skip to content

Commit

Permalink
Add unit test
Browse files Browse the repository at this point in the history
  • Loading branch information
C1rN09 committed Sep 8, 2022
1 parent 6653191 commit b54e8a5
Showing 1 changed file with 97 additions and 0 deletions.
97 changes: 97 additions & 0 deletions tests/test_hooks/test_checkpoint_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,69 @@
from unittest.mock import Mock, patch

import pytest
import torch
import torch.nn as nn
from torch.utils.data import Dataset

from mmengine.evaluator import BaseMetric
from mmengine.hooks import CheckpointHook
from mmengine.logging import MessageHub
from mmengine.model import BaseModel
from mmengine.optim import OptimWrapper
from mmengine.runner import Runner


class ToyModel(BaseModel):

def __init__(self):
super().__init__()
self.linear = nn.Linear(2, 1)

def forward(self, inputs, data_sample, mode='tensor'):
labels = torch.stack(data_sample)
inputs = torch.stack(inputs)
outputs = self.linear(inputs)
if mode == 'tensor':
return outputs
elif mode == 'loss':
loss = (labels - outputs).sum()
outputs = dict(loss=loss)
return outputs
else:
return outputs


class DummyDataset(Dataset):
METAINFO = dict() # type: ignore
data = torch.randn(12, 2)
label = torch.ones(12)

@property
def metainfo(self):
return self.METAINFO

def __len__(self):
return self.data.size(0)

def __getitem__(self, index):
return dict(inputs=self.data[index], data_sample=self.label[index])


class TriangleMetric(BaseMetric):

def __init__(self, length):
super().__init__()
self.length = length
self.best_idx = length // 2
self.cur_idx = 0

def process(self, *args, **kwargs):
pass

def compute_metrics(self, *args, **kwargs):
self.cur_idx += 1
acc = 1.0 - abs(self.cur_idx - self.best_idx) / self.length
return dict(acc=acc)


class MockPetrel:
Expand Down Expand Up @@ -370,3 +430,40 @@ def test_after_train_iter(self, tmp_path):
checkpoint_hook.before_train(runner)
checkpoint_hook.after_train_iter(runner, batch_idx=batch_idx)
assert not os.path.exists(f'{work_dir}/iter_8.pth')

def test_with_runner(self, tmp_path):
max_epoch = 10
work_dir = osp.join(str(tmp_path), 'runner_test')
tmpl = '{}.pth'
save_interval = 2
checkpoint_cfg = dict(
type='CheckpointHook',
interval=save_interval,
filename_tmpl=tmpl,
by_epoch=True)
runner = Runner(
model=ToyModel(),
work_dir=work_dir,
train_dataloader=dict(
dataset=DummyDataset(),
sampler=dict(type='DefaultSampler', shuffle=True),
batch_size=3,
num_workers=0),
val_dataloader=dict(
dataset=DummyDataset(),
sampler=dict(type='DefaultSampler', shuffle=False),
batch_size=3,
num_workers=0),
val_evaluator=dict(type=TriangleMetric, length=max_epoch),
optim_wrapper=OptimWrapper(
torch.optim.Adam(ToyModel().parameters())),
train_cfg=dict(
by_epoch=True, max_epochs=max_epoch, val_interval=1),
val_cfg=dict(),
default_hooks=dict(checkpoint=checkpoint_cfg))
runner.train()
for epoch in range(max_epoch):
if epoch % save_interval != 0 or epoch == 0:
continue
path = osp.join(work_dir, tmpl.format(epoch))
assert osp.isfile(path=path)

0 comments on commit b54e8a5

Please sign in to comment.