Skip to content

Commit

Permalink
Updates model, adds colab, fixes row_log_interval
Browse files Browse the repository at this point in the history
  • Loading branch information
mishushakov committed Dec 13, 2020
1 parent 15ee073 commit e561a74
Show file tree
Hide file tree
Showing 6 changed files with 317 additions and 11 deletions.
10 changes: 0 additions & 10 deletions .pre-commit-config.yaml

This file was deleted.

Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added models/pedalnet/signal_comparison_e2s_0.0142.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
316 changes: 316 additions & 0 deletions notebooks/colab_GPU_playground.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,316 @@
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "PedalNetRT Playground.ipynb",
"provenance": [],
"collapsed_sections": [],
"toc_visible": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"accelerator": "GPU"
},
"cells": [
{
"cell_type": "code",
"metadata": {
"id": "a4W98Rf771lR",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "73781a99-1ee8-4f9f-92f5-3486f0eba18c"
},
"source": [
"!git clone https://github.com/GuitarML/PedalNetRT.git"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"text": [
"Cloning into 'PedalNetRT'...\n",
"remote: Enumerating objects: 10, done.\u001b[K\n",
"remote: Counting objects: 100% (10/10), done.\u001b[K\n",
"remote: Compressing objects: 100% (9/9), done.\u001b[K\n",
"remote: Total 226 (delta 0), reused 4 (delta 0), pack-reused 216\u001b[K\n",
"Receiving objects: 100% (226/226), 85.04 MiB | 27.51 MiB/s, done.\n",
"Resolving deltas: 100% (98/98), done.\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "WOVSyPgGMcYi",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "28907ebb-6e76-4571-b4b6-ae48e4fd4dae"
},
"source": [
"%cd PedalNetRT"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"text": [
"/content/PedalNetRT\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "Vv3lwCpPudVb",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "b9883fe9-6f16-426e-982a-4e525f24dcf7"
},
"source": [
"!pip install -r requirements-colab.txt"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"text": [
"Collecting pytorch_lightning==0.7.3\n",
"\u001b[?25l Downloading https://files.pythonhosted.org/packages/4e/53/0549dd9c44c90e96d217592e094e9c53ef39ae2fed0c5cdb7e57aca65af6/pytorch_lightning-0.7.3-py3-none-any.whl (203kB)\n",
"\u001b[K |████████████████████████████████| 204kB 11.9MB/s \n",
"\u001b[?25hCollecting future>=0.17.1\n",
"\u001b[?25l Downloading https://files.pythonhosted.org/packages/45/0b/38b06fd9b92dc2b68d58b75f900e97884c45bedd2ff83203d933cf5851c9/future-0.18.2.tar.gz (829kB)\n",
"\u001b[K |████████████████████████████████| 829kB 12.7MB/s \n",
"\u001b[?25hRequirement already satisfied: tensorboard>=1.14 in /usr/local/lib/python3.6/dist-packages (from pytorch_lightning==0.7.3) (2.3.0)\n",
"Requirement already satisfied: torch>=1.1 in /usr/local/lib/python3.6/dist-packages (from pytorch_lightning==0.7.3) (1.7.0+cu101)\n",
"Requirement already satisfied: numpy>=1.16.4 in /usr/local/lib/python3.6/dist-packages (from pytorch_lightning==0.7.3) (1.18.5)\n",
"Requirement already satisfied: tqdm>=4.41.0 in /usr/local/lib/python3.6/dist-packages (from pytorch_lightning==0.7.3) (4.41.1)\n",
"Requirement already satisfied: google-auth<2,>=1.6.3 in /usr/local/lib/python3.6/dist-packages (from tensorboard>=1.14->pytorch_lightning==0.7.3) (1.17.2)\n",
"Requirement already satisfied: absl-py>=0.4 in /usr/local/lib/python3.6/dist-packages (from tensorboard>=1.14->pytorch_lightning==0.7.3) (0.10.0)\n",
"Requirement already satisfied: wheel>=0.26; python_version >= \"3\" in /usr/local/lib/python3.6/dist-packages (from tensorboard>=1.14->pytorch_lightning==0.7.3) (0.35.1)\n",
"Requirement already satisfied: requests<3,>=2.21.0 in /usr/local/lib/python3.6/dist-packages (from tensorboard>=1.14->pytorch_lightning==0.7.3) (2.23.0)\n",
"Requirement already satisfied: six>=1.10.0 in /usr/local/lib/python3.6/dist-packages (from tensorboard>=1.14->pytorch_lightning==0.7.3) (1.15.0)\n",
"Requirement already satisfied: grpcio>=1.24.3 in /usr/local/lib/python3.6/dist-packages (from tensorboard>=1.14->pytorch_lightning==0.7.3) (1.33.2)\n",
"Requirement already satisfied: tensorboard-plugin-wit>=1.6.0 in /usr/local/lib/python3.6/dist-packages (from tensorboard>=1.14->pytorch_lightning==0.7.3) (1.7.0)\n",
"Requirement already satisfied: setuptools>=41.0.0 in /usr/local/lib/python3.6/dist-packages (from tensorboard>=1.14->pytorch_lightning==0.7.3) (50.3.2)\n",
"Requirement already satisfied: markdown>=2.6.8 in /usr/local/lib/python3.6/dist-packages (from tensorboard>=1.14->pytorch_lightning==0.7.3) (3.3.3)\n",
"Requirement already satisfied: protobuf>=3.6.0 in /usr/local/lib/python3.6/dist-packages (from tensorboard>=1.14->pytorch_lightning==0.7.3) (3.12.4)\n",
"Requirement already satisfied: werkzeug>=0.11.15 in /usr/local/lib/python3.6/dist-packages (from tensorboard>=1.14->pytorch_lightning==0.7.3) (1.0.1)\n",
"Requirement already satisfied: google-auth-oauthlib<0.5,>=0.4.1 in /usr/local/lib/python3.6/dist-packages (from tensorboard>=1.14->pytorch_lightning==0.7.3) (0.4.2)\n",
"Requirement already satisfied: typing-extensions in /usr/local/lib/python3.6/dist-packages (from torch>=1.1->pytorch_lightning==0.7.3) (3.7.4.3)\n",
"Requirement already satisfied: dataclasses in /usr/local/lib/python3.6/dist-packages (from torch>=1.1->pytorch_lightning==0.7.3) (0.8)\n",
"Requirement already satisfied: rsa<5,>=3.1.4; python_version >= \"3\" in /usr/local/lib/python3.6/dist-packages (from google-auth<2,>=1.6.3->tensorboard>=1.14->pytorch_lightning==0.7.3) (4.6)\n",
"Requirement already satisfied: cachetools<5.0,>=2.0.0 in /usr/local/lib/python3.6/dist-packages (from google-auth<2,>=1.6.3->tensorboard>=1.14->pytorch_lightning==0.7.3) (4.1.1)\n",
"Requirement already satisfied: pyasn1-modules>=0.2.1 in /usr/local/lib/python3.6/dist-packages (from google-auth<2,>=1.6.3->tensorboard>=1.14->pytorch_lightning==0.7.3) (0.2.8)\n",
"Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.6/dist-packages (from requests<3,>=2.21.0->tensorboard>=1.14->pytorch_lightning==0.7.3) (3.0.4)\n",
"Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.6/dist-packages (from requests<3,>=2.21.0->tensorboard>=1.14->pytorch_lightning==0.7.3) (1.24.3)\n",
"Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.6/dist-packages (from requests<3,>=2.21.0->tensorboard>=1.14->pytorch_lightning==0.7.3) (2.10)\n",
"Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.6/dist-packages (from requests<3,>=2.21.0->tensorboard>=1.14->pytorch_lightning==0.7.3) (2020.11.8)\n",
"Requirement already satisfied: importlib-metadata; python_version < \"3.8\" in /usr/local/lib/python3.6/dist-packages (from markdown>=2.6.8->tensorboard>=1.14->pytorch_lightning==0.7.3) (2.0.0)\n",
"Requirement already satisfied: requests-oauthlib>=0.7.0 in /usr/local/lib/python3.6/dist-packages (from google-auth-oauthlib<0.5,>=0.4.1->tensorboard>=1.14->pytorch_lightning==0.7.3) (1.3.0)\n",
"Requirement already satisfied: pyasn1>=0.1.3 in /usr/local/lib/python3.6/dist-packages (from rsa<5,>=3.1.4; python_version >= \"3\"->google-auth<2,>=1.6.3->tensorboard>=1.14->pytorch_lightning==0.7.3) (0.4.8)\n",
"Requirement already satisfied: zipp>=0.5 in /usr/local/lib/python3.6/dist-packages (from importlib-metadata; python_version < \"3.8\"->markdown>=2.6.8->tensorboard>=1.14->pytorch_lightning==0.7.3) (3.4.0)\n",
"Requirement already satisfied: oauthlib>=3.0.0 in /usr/local/lib/python3.6/dist-packages (from requests-oauthlib>=0.7.0->google-auth-oauthlib<0.5,>=0.4.1->tensorboard>=1.14->pytorch_lightning==0.7.3) (3.1.0)\n",
"Building wheels for collected packages: future\n",
" Building wheel for future (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
" Created wheel for future: filename=future-0.18.2-cp36-none-any.whl size=491057 sha256=745be7372a3c702a9eb466c1cb82625633700e071d9cd422b6f7da25da5d8cdf\n",
" Stored in directory: /root/.cache/pip/wheels/8b/99/a0/81daf51dcd359a9377b110a8a886b3895921802d2fc1b2397e\n",
"Successfully built future\n",
"Installing collected packages: future, pytorch-lightning\n",
" Found existing installation: future 0.16.0\n",
" Uninstalling future-0.16.0:\n",
" Successfully uninstalled future-0.16.0\n",
"Successfully installed future-0.18.2 pytorch-lightning-0.7.3\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "n9vktclZBGuc"
},
"source": [
"!python3 prepare.py \"data/ts9_in.wav\" \"data/ts9_out.wav\""
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "XBb2THG2BtLm",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "17f2b4ec-51ba-487d-f745-9786de93df64"
},
"source": [
"!python3 train.py"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"text": [
"\rEpoch 1484: 75% 18/24 [00:02<00:00, 7.85it/s, loss=0.026, v_num=2]\n",
"\rValidating: 0% 0/6 [00:00<?, ?it/s]\u001b[A\n",
"Epoch 1484: 83% 20/24 [00:02<00:00, 8.16it/s, loss=0.026, v_num=2]\n",
"Epoch 1484: 100% 24/24 [00:02<00:00, 8.97it/s, loss=0.026, v_num=2]\n",
"Epoch 1485: 75% 18/24 [00:02<00:00, 7.84it/s, loss=0.026, v_num=2]\n",
"Validating: 0% 0/6 [00:00<?, ?it/s]\u001b[A\n",
"Epoch 1485: 83% 20/24 [00:02<00:00, 8.12it/s, loss=0.026, v_num=2]\n",
"Epoch 1485: 100% 24/24 [00:02<00:00, 9.02it/s, loss=0.026, v_num=2]\n",
"Epoch 1486: 75% 18/24 [00:02<00:00, 7.80it/s, loss=0.029, v_num=2]\n",
"Validating: 0% 0/6 [00:00<?, ?it/s]\u001b[A\n",
"Epoch 1486: 83% 20/24 [00:02<00:00, 8.04it/s, loss=0.029, v_num=2]\n",
"Epoch 1486: 100% 24/24 [00:02<00:00, 8.88it/s, loss=0.029, v_num=2]\n",
"Epoch 1487: 75% 18/24 [00:02<00:00, 7.84it/s, loss=0.040, v_num=2]\n",
"Validating: 0% 0/6 [00:00<?, ?it/s]\u001b[A\n",
"Epoch 1487: 83% 20/24 [00:02<00:00, 8.04it/s, loss=0.040, v_num=2]\n",
"Epoch 1487: 100% 24/24 [00:02<00:00, 8.91it/s, loss=0.040, v_num=2]\n",
"Epoch 1488: 75% 18/24 [00:02<00:00, 7.81it/s, loss=0.032, v_num=2]\n",
"Validating: 0% 0/6 [00:00<?, ?it/s]\u001b[A\n",
"Epoch 1488: 83% 20/24 [00:02<00:00, 8.05it/s, loss=0.032, v_num=2]\n",
"Epoch 1488: 100% 24/24 [00:02<00:00, 8.92it/s, loss=0.032, v_num=2]\n",
"Epoch 1489: 75% 18/24 [00:02<00:00, 7.73it/s, loss=0.031, v_num=2]\n",
"Validating: 0% 0/6 [00:00<?, ?it/s]\u001b[A\n",
"Epoch 1489: 83% 20/24 [00:02<00:00, 7.98it/s, loss=0.031, v_num=2]\n",
"Epoch 1489: 100% 24/24 [00:02<00:00, 8.86it/s, loss=0.031, v_num=2]\n",
"Epoch 1490: 75% 18/24 [00:02<00:00, 7.75it/s, loss=0.035, v_num=2]\n",
"Validating: 0% 0/6 [00:00<?, ?it/s]\u001b[A\n",
"Epoch 1490: 83% 20/24 [00:02<00:00, 8.01it/s, loss=0.035, v_num=2]\n",
"Epoch 1490: 100% 24/24 [00:02<00:00, 8.91it/s, loss=0.035, v_num=2]\n",
"Epoch 1491: 75% 18/24 [00:02<00:00, 7.75it/s, loss=0.026, v_num=2]\n",
"Validating: 0% 0/6 [00:00<?, ?it/s]\u001b[A\n",
"Epoch 1491: 83% 20/24 [00:02<00:00, 8.01it/s, loss=0.026, v_num=2]\n",
"Epoch 1491: 100% 24/24 [00:02<00:00, 8.88it/s, loss=0.026, v_num=2]\n",
"Epoch 1492: 75% 18/24 [00:02<00:00, 7.86it/s, loss=0.026, v_num=2]\n",
"Validating: 0% 0/6 [00:00<?, ?it/s]\u001b[A\n",
"Epoch 1492: 83% 20/24 [00:02<00:00, 8.07it/s, loss=0.026, v_num=2]\n",
"Epoch 1492: 100% 24/24 [00:02<00:00, 8.88it/s, loss=0.026, v_num=2]\n",
"Epoch 1493: 75% 18/24 [00:02<00:00, 7.88it/s, loss=0.026, v_num=2]\n",
"Validating: 0% 0/6 [00:00<?, ?it/s]\u001b[A\n",
"Epoch 1493: 83% 20/24 [00:02<00:00, 8.09it/s, loss=0.026, v_num=2]\n",
"Epoch 1493: 100% 24/24 [00:02<00:00, 8.96it/s, loss=0.026, v_num=2]\n",
"Epoch 1494: 75% 18/24 [00:02<00:00, 7.89it/s, loss=0.026, v_num=2]\n",
"Validating: 0% 0/6 [00:00<?, ?it/s]\u001b[A\n",
"Epoch 1494: 83% 20/24 [00:02<00:00, 8.17it/s, loss=0.026, v_num=2]\n",
"Epoch 1494: 100% 24/24 [00:02<00:00, 9.02it/s, loss=0.026, v_num=2]\n",
"Epoch 1495: 75% 18/24 [00:02<00:00, 7.74it/s, loss=0.026, v_num=2]\n",
"Validating: 0% 0/6 [00:00<?, ?it/s]\u001b[A\n",
"Epoch 1495: 83% 20/24 [00:02<00:00, 7.97it/s, loss=0.026, v_num=2]\n",
"Epoch 1495: 100% 24/24 [00:02<00:00, 8.82it/s, loss=0.026, v_num=2]\n",
"Epoch 1496: 75% 18/24 [00:02<00:00, 7.85it/s, loss=0.026, v_num=2]\n",
"Validating: 0% 0/6 [00:00<?, ?it/s]\u001b[A\n",
"Epoch 1496: 83% 20/24 [00:02<00:00, 8.10it/s, loss=0.026, v_num=2]\n",
"Epoch 1496: 100% 24/24 [00:02<00:00, 8.91it/s, loss=0.026, v_num=2]\n",
"Epoch 1497: 75% 18/24 [00:02<00:00, 7.87it/s, loss=0.025, v_num=2]\n",
"Validating: 0% 0/6 [00:00<?, ?it/s]\u001b[A\n",
"Epoch 1497: 83% 20/24 [00:02<00:00, 8.05it/s, loss=0.025, v_num=2]\n",
"Epoch 1497: 100% 24/24 [00:02<00:00, 8.90it/s, loss=0.025, v_num=2]\n",
"Epoch 1498: 75% 18/24 [00:02<00:00, 7.78it/s, loss=0.025, v_num=2]\n",
"Validating: 0% 0/6 [00:00<?, ?it/s]\u001b[A\n",
"Epoch 1498: 83% 20/24 [00:02<00:00, 8.08it/s, loss=0.025, v_num=2]\n",
"Epoch 1498: 100% 24/24 [00:02<00:00, 8.91it/s, loss=0.025, v_num=2]\n",
"Epoch 1499: 75% 18/24 [00:02<00:00, 7.82it/s, loss=0.027, v_num=2]\n",
"Validating: 0% 0/6 [00:00<?, ?it/s]\u001b[A\n",
"Epoch 1499: 83% 20/24 [00:02<00:00, 8.11it/s, loss=0.027, v_num=2]\n",
"Epoch 1499: 100% 24/24 [00:02<00:00, 8.92it/s, loss=0.027, v_num=2]\n",
"Epoch 1500: 75% 18/24 [00:02<00:00, 7.70it/s, loss=0.028, v_num=2]\n",
"Validating: 0% 0/6 [00:00<?, ?it/s]\u001b[A\n",
"Epoch 1500: 83% 20/24 [00:02<00:00, 7.95it/s, loss=0.028, v_num=2]\n",
"Epoch 1500: 100% 24/24 [00:02<00:00, 8.78it/s, loss=0.028, v_num=2]\n",
"Epoch 1500: 100% 24/24 [00:02<00:00, 8.78it/s, loss=0.028, v_num=2]\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "KZ6wG3E0rmQO"
},
"source": [
"!python3 test.py"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "m_KwfnpXryEY",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "805f9260-0358-4906-85b2-2b6114ed2656"
},
"source": [
"!python3 plot.py"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"text": [
"Error to signal (with pre-emphasis filter): 0.014186971\n",
"Error to signal (no pre-emphasis filter): 0.014186973\n",
"Creating spectrogram data..\n",
"plot.py:127: RuntimeWarning: divide by zero encountered in log10\n",
" plt.pcolormesh(times, frequencies, 10 * np.log10(spectrogram))\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "vCJoHzyKDxBl"
},
"source": [
"!python3 export.py"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "oIEkvTqVYKAX",
"outputId": "f7bf2f8a-9cd4-4e74-ba82-cea81487f3b0"
},
"source": [
"!python predict.py data/ts9_in.wav models/pedalnet/predict.wav"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"text": [
"\r 0% 0/7 [00:00<?, ?it/s]tcmalloc: large alloc 1507164160 bytes == 0xd8862000 @ 0x7f2509ad0b6b 0x7f2509af0379 0x7f24ba9fd74e 0x7f24ba9ff7b6 0x7f24f5469d53 0x7f24f4de454a 0x7f24f513ec0a 0x7f24f5166803 0x7f24f52ecb14 0x7f24f54294ee 0x7f24f4e82976 0x7f24f4e83b30 0x7f24f5140b09 0x7f24f49bf249 0x7f24f52d9ae8 0x7f24f51e58a5 0x7f24f4e8541b 0x7f24f53757d8 0x7f24f49bf249 0x7f24f52d9ae8 0x7f24f51e59f5 0x7f24f67b9997 0x7f24f49bf249 0x7f24f52d9ae8 0x7f24f51e59f5 0x7f2504fff30e 0x50a4a5 0x50cc96 0x507be4 0x508ec2 0x594a01\n",
"tcmalloc: large alloc 1507164160 bytes == 0x1325ba000 @ 0x7f2509ad0b6b 0x7f2509af0379 0x7f24ba9fd74e 0x7f24ba9ff7b6 0x7f24f50c0271 0x7f24f50b0489 0x7f24f50b0e30 0x7f24f50babda 0x7f24f50beb57 0x7f24f50b7cf6 0x7f24f50b877f 0x7f24f537d5b6 0x7f24f53c179c 0x7f24f4a35e53 0x7f24f52de4bb 0x7f24f51eebbb 0x7f24f6733e0d 0x7f24f4a35e53 0x7f24f52de4bb 0x7f24f51eebbb 0x7f24f4bbed00 0x7f24f5376803 0x7f24f53c0be3 0x7f24f4a35dba 0x7f24f52dca1c 0x7f24f51ed4c4 0x7f24f4bb5e16 0x7f24f537650d 0x7f24f53c0aa3 0x7f24f4a35f0f 0x7f24f52dbdc5\n",
" 14% 1/7 [00:11<01:09, 11.57s/it]tcmalloc: large alloc 1507164160 bytes == 0xab372000 @ 0x7f2509ad0b6b 0x7f2509af0379 0x7f24ba9fd74e 0x7f24ba9ff7b6 0x7f24f5469d53 0x7f24f4de454a 0x7f24f513ec0a 0x7f24f5166803 0x7f24f52ecb14 0x7f24f54294ee 0x7f24f4e82976 0x7f24f4e83b30 0x7f24f5140b09 0x7f24f49bf249 0x7f24f52d9ae8 0x7f24f51e58a5 0x7f24f4e8541b 0x7f24f53757d8 0x7f24f49bf249 0x7f24f52d9ae8 0x7f24f51e59f5 0x7f24f67b9997 0x7f24f49bf249 0x7f24f52d9ae8 0x7f24f51e59f5 0x7f2504fff30e 0x50a4a5 0x50cc96 0x507be4 0x508ec2 0x594a01\n",
"100% 7/7 [01:07<00:00, 9.67s/it]\n"
],
"name": "stdout"
}
]
}
]
}
1 change: 1 addition & 0 deletions requirements-colab.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
pytorch_lightning==1.0.8
1 change: 0 additions & 1 deletion requirements-dev.txt

This file was deleted.

0 comments on commit e561a74

Please sign in to comment.