Skip to content

Commit

Permalink
Add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
tarepan committed Oct 24, 2023
1 parent 310f62c commit ce8a736
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 0 deletions.
16 changes: 16 additions & 0 deletions hubconf_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
"""Test torch.hub basic functionalities."""

from hubconf import utmos22_strong
from speechmos.utmos22.strong.model import UTMOS22Strong


def test_utmos22_strong_init():
"""Test `utmos22_strong` instantiation without weight load."""

# Test - progress=True
model = utmos22_strong(progress=True, pretrained=False)
assert isinstance(model, UTMOS22Strong), "UTMOS22Strong not properly instantiated."

# Test - progress=False
model = utmos22_strong(progress=False, pretrained=False)
assert isinstance(model, UTMOS22Strong), "UTMOS22Strong not properly instantiated."
47 changes: 47 additions & 0 deletions speechmos/utmos22/strong/model_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
"""Test the utmos22-strong model."""

import torch

from .model import UTMOS22Strong


def test_model_init():
"""Test the `UTMOS22Strong` instantiation."""

# Preparation
UTMOS22Strong()

# Test
assert True, "UTMOS22Strong is not properly instantiated."


def test_model_forward():
"""Test the `UTMOS22Strong` forward run."""

# Preparation
model = UTMOS22Strong()
sr = 16000
ipt = torch.tensor([1. for _ in range(int(sr * 0.5))]).unsqueeze(0)

# Prerequesite Test
assert ipt.size() == (1, 8000), "Prerequesites are not satisfied."

# Test
model(ipt, sr)
assert True, "UTMOS22Strong is not properly forwarded."


def test_model_output_shape():
"""Test the `UTMOS22Strong` forward output shape."""

# Preparation
model = UTMOS22Strong()
sr = 16000
ipt = torch.tensor([1. for _ in range(int(sr * 0.5))]).unsqueeze(0)

# Prerequesite Test
assert ipt.size() == (1, 8000), "Prerequesites are not satisfied."

# Test
opt = model(ipt, sr)
assert opt.size() == (1,), "UTMOS22Strong is not properly forwarded."

0 comments on commit ce8a736

Please sign in to comment.