Skip to content

Commit 3effbed

Browse files
committed
add helper script
1 parent d5d0a47 commit 3effbed

File tree

1 file changed

+54
-0
lines changed

1 file changed

+54
-0
lines changed

repro_dsp.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
import torch
2+
from torchaudio.prototype.functional import oscillator_bank, adsr_envelope
3+
4+
import matplotlib.pyplot as plt
5+
6+
7+
def osci():
8+
sample_rate = 8000
9+
num_frames = 8000 * 3
10+
11+
freq = torch.ones((num_frames, 1)) * torch.tensor([[2000]]) # , 3000]])
12+
amps = torch.ones_like(freq)
13+
waveform32 = oscillator_bank(freq, amps, sample_rate)
14+
15+
freq = freq.to(torch.float64)
16+
amps = freq.to(torch.float64)
17+
waveform64 = oscillator_bank(freq, amps, sample_rate)
18+
19+
fig, axes = plt.subplots(2, 1, sharex=True, sharey=True)
20+
fig.suptitle("Precision and waveform generated by oscillator_bank")
21+
_, _, _, cax = axes[0].specgram(waveform32, Fs=sample_rate)
22+
axes[0].set(title="float32", ylabel="Frequency [Hz]")
23+
fig.colorbar(cax)
24+
25+
_, _, _, cax = axes[1].specgram(waveform64, Fs=sample_rate)
26+
axes[1].set(title="float64", xlabel="time [s]", ylabel="Frequency [Hz]")
27+
plt.colorbar(cax)
28+
plt.tight_layout()
29+
plt.subplots_adjust(left=0.2, top=0.85, right=0.9,bottom=0.15)
30+
fig.savefig("oscillator_precision.png", dpi=200, transparent=True)
31+
32+
33+
def adsr():
34+
num_frames = 8000
35+
configs = [
36+
{"attack": 0.2, "hold": 0.2, "decay": 0.2, "sustain": 0.5, "release": 0.2},
37+
{"attack": 0.02, "decay": 0.98, "sustain": 0, "release": 0},
38+
{"attack": 0.01, "hold": 0.3, "decay": 0.05, "sustain": 0.01, "release": 0},
39+
{"attack": 0.98, "decay": 0, "sustain": 1, "release": 0.02},
40+
]
41+
waveforms = [adsr_envelope(**config, num_frames=num_frames) for config in configs]
42+
t = torch.linspace(0, 1.0, num_frames)
43+
44+
fig, axes = plt.subplots(len(configs), 1, sharex=True, sharey=True)
45+
for ax, config, waveform in zip(axes, configs, waveforms):
46+
ax.plot(t, waveform)
47+
ax.grid(True)
48+
ax.set(title=', '.join(f'{k}: {v}' for k, v in config.items()))
49+
fig.tight_layout()
50+
fig.savefig("adsr.png", dpi=200)
51+
52+
# osci()
53+
adsr()
54+
plt.show()

0 commit comments

Comments
 (0)