|
| 1 | +import torch |
| 2 | +from torchaudio.prototype.functional import oscillator_bank, adsr_envelope |
| 3 | + |
| 4 | +import matplotlib.pyplot as plt |
| 5 | + |
| 6 | + |
| 7 | +def osci(): |
| 8 | + sample_rate = 8000 |
| 9 | + num_frames = 8000 * 3 |
| 10 | + |
| 11 | + freq = torch.ones((num_frames, 1)) * torch.tensor([[2000]]) # , 3000]]) |
| 12 | + amps = torch.ones_like(freq) |
| 13 | + waveform32 = oscillator_bank(freq, amps, sample_rate) |
| 14 | + |
| 15 | + freq = freq.to(torch.float64) |
| 16 | + amps = freq.to(torch.float64) |
| 17 | + waveform64 = oscillator_bank(freq, amps, sample_rate) |
| 18 | + |
| 19 | + fig, axes = plt.subplots(2, 1, sharex=True, sharey=True) |
| 20 | + fig.suptitle("Precision and waveform generated by oscillator_bank") |
| 21 | + _, _, _, cax = axes[0].specgram(waveform32, Fs=sample_rate) |
| 22 | + axes[0].set(title="float32", ylabel="Frequency [Hz]") |
| 23 | + fig.colorbar(cax) |
| 24 | + |
| 25 | + _, _, _, cax = axes[1].specgram(waveform64, Fs=sample_rate) |
| 26 | + axes[1].set(title="float64", xlabel="time [s]", ylabel="Frequency [Hz]") |
| 27 | + plt.colorbar(cax) |
| 28 | + plt.tight_layout() |
| 29 | + plt.subplots_adjust(left=0.2, top=0.85, right=0.9,bottom=0.15) |
| 30 | + fig.savefig("oscillator_precision.png", dpi=200, transparent=True) |
| 31 | + |
| 32 | + |
| 33 | +def adsr(): |
| 34 | + num_frames = 8000 |
| 35 | + configs = [ |
| 36 | + {"attack": 0.2, "hold": 0.2, "decay": 0.2, "sustain": 0.5, "release": 0.2}, |
| 37 | + {"attack": 0.02, "decay": 0.98, "sustain": 0, "release": 0}, |
| 38 | + {"attack": 0.01, "hold": 0.3, "decay": 0.05, "sustain": 0.01, "release": 0}, |
| 39 | + {"attack": 0.98, "decay": 0, "sustain": 1, "release": 0.02}, |
| 40 | + ] |
| 41 | + waveforms = [adsr_envelope(**config, num_frames=num_frames) for config in configs] |
| 42 | + t = torch.linspace(0, 1.0, num_frames) |
| 43 | + |
| 44 | + fig, axes = plt.subplots(len(configs), 1, sharex=True, sharey=True) |
| 45 | + for ax, config, waveform in zip(axes, configs, waveforms): |
| 46 | + ax.plot(t, waveform) |
| 47 | + ax.grid(True) |
| 48 | + ax.set(title=', '.join(f'{k}: {v}' for k, v in config.items())) |
| 49 | + fig.tight_layout() |
| 50 | + fig.savefig("adsr.png", dpi=200) |
| 51 | + |
| 52 | +# osci() |
| 53 | +adsr() |
| 54 | +plt.show() |
0 commit comments