Skip to content

Commit

Permalink
add tests and fix bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
timothyxp committed Oct 12, 2021
1 parent 1985921 commit 4a98b8f
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 2 deletions.
8 changes: 6 additions & 2 deletions hw_asr/logger/wandb.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,10 @@ def add_image(self, scalar_name, image):
self.scalar_name(scalar_name): self.wandb.Image(image)
}, step=self.step)

def add_audio(self, scalar_name, audio):
def add_audio(self, scalar_name, audio, sample_rate=None):
audio = audio.detach().cpu().numpy().T
self.wandb.log({
self.scalar_name(scalar_name): self.wandb.Audio(audio)
self.scalar_name(scalar_name): self.wandb.Audio(audio, sample_rate=sample_rate)
}, step=self.step)

def add_text(self, scalar_name, text):
Expand All @@ -79,6 +80,9 @@ def add_histogram(self, scalar_name, hist, bins=None):
self.scalar_name(scalar_name): hist
}, step=self.step)

def add_images(self, scalar_name, images):
raise NotImplementedError()

def add_pr_curve(self, scalar_name, scalar):
raise NotImplementedError()

Expand Down
73 changes: 73 additions & 0 deletions hw_asr/tests/test_visualization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
import unittest
from hw_asr.logger.tensorboard import TensorboardWriter
from hw_asr.logger.utils import plot_spectrogram_to_buf
from hw_asr.logger.wandb import WanDBWriter
from hw_asr.utils.parse_config import ConfigParser
import shutil
import torchaudio
from torchvision.transforms import ToTensor
from pathlib import Path
import numpy as np
import torch
import PIL


class TestVisualization(unittest.TestCase):
def test_visualiaers(self):
log_dir = str(Path(__file__).parent / "logs_dir")

try:
config = ConfigParser.get_default_configs()
logger = config.get_logger("test")

tensorboard = TensorboardWriter(log_dir, logger, True)
wandb = WanDBWriter(config, logger)

test_methods = [
"add_scalar",
"add_scalars",
"add_image",
"add_audio",
"add_text",
"add_histogram"
]

audio_path = Path(__file__).parent.parent.parent / "test_data" / "audio" / "84-121550-0000.flac"
audio, sr = torchaudio.load(audio_path)
print(audio.shape)
wave2spec = config.init_obj(
config["preprocessing"]["spectrogram"],
torchaudio.transforms,
)

wave = wave2spec(audio)
image = ToTensor()(PIL.Image.open(plot_spectrogram_to_buf(wave.squeeze(0).log())))
print(image.shape)

hist = torch.from_numpy(np.asarray([1, 2, 3, 4]))

test_data = [
1,
{"test1": 1, "test2": 2},
image,
audio,
"test",
hist
]

for method, value in zip(test_methods, test_data):
kwargs = {}
if method == 'add_audio':
kwargs = {'sample_rate': sr}
elif method == 'add_histogram':
kwargs = {'bins': 'auto'}

logger.info(f"test {method}")
getattr(tensorboard, method)(method, value, **kwargs)
getattr(wandb, method)(method, value, **kwargs)

finally:
shutil.rmtree(log_dir)



0 comments on commit 4a98b8f

Please sign in to comment.