Implementation of Band Split Roformer, a state-of-the-art (SOTA) attention network for music source separation from ByteDance AI Labs. It beat the previous first-place system by a large margin. The technique uses axial attention across frequency (hence multi-band) and time. The authors also ran experiments showing that rotary positional encoding gives a large improvement over learned absolute positions.
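The axial attention and rotary encoding described above can be sketched in a few lines of NumPy. This is a simplified illustration, not the library's code: it assumes a single attention head, identity query/key/value projections, and no normalization or feedforward layers, and the names `rotary`, `attention` and `axial_block` are hypothetical. Rotary embeddings rotate each pair of query/key features by a position-dependent angle, so attention scores depend only on relative position. The block first attends across time within each band, then across bands within each frame.

```python
import numpy as np

def rotary(x, base=10000.0):
    """Rotate feature pairs of x (..., seq, dim) by position-dependent angles; dim must be even."""
    seq, dim = x.shape[-2], x.shape[-1]
    inv_freq = base ** (-np.arange(0, dim, 2) / dim)        # one frequency per feature pair
    angles = np.arange(seq)[:, None] * inv_freq[None, :]     # (seq, dim // 2)
    cos, sin = np.cos(angles), np.sin(angles)
    x1, x2 = x[..., 0::2], x[..., 1::2]
    out = np.empty_like(x)
    out[..., 0::2] = x1 * cos - x2 * sin
    out[..., 1::2] = x1 * sin + x2 * cos
    return out

def attention(x):
    """Single-head softmax self-attention over axis -2, with RoPE applied to queries and keys."""
    q, k, v = rotary(x), rotary(x), x                        # identity projections (simplification)
    scores = q @ np.swapaxes(k, -1, -2) / np.sqrt(x.shape[-1])
    scores -= scores.max(axis=-1, keepdims=True)             # numerical stability
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ v

def axial_block(x):
    """x: (batch, frames, bands, dim). Attend across time per band, then across bands per frame."""
    x = x + np.swapaxes(attention(np.swapaxes(x, 1, 2)), 1, 2)  # time axis, residual connection
    x = x + attention(x)                                        # band axis, residual connection
    return x

x = np.random.default_rng(0).standard_normal((1, 100, 8, 16))
print(axial_block(x).shape)  # (1, 100, 8, 16)
```

Because the rotary embedding is applied inside each attention call, the time pass encodes frame positions and the frequency pass encodes band positions, with no learned position table.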
This repository also supports stereo training and outputting multiple stems.
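At the level of array shapes, stereo input and multi-stem output amount to an extra channel axis and an extra stem axis. The sketch below assumes the model predicts one complex mask per stem and channel, which is multiplied with the mixture's complex STFT; the function name `separate` is hypothetical and the masks are random stand-ins for model output.

```python
import numpy as np

def separate(mix_spec, masks):
    """Apply one complex mask per stem to a (possibly stereo) complex spectrogram.

    mix_spec: complex (channels, freqs, frames), channels = 2 for stereo
    masks:    complex (stems, channels, freqs, frames)
    returns:  complex (stems, channels, freqs, frames), one estimate per stem
    """
    # a leading stem axis lets the single mixture broadcast against every stem's mask
    return masks * mix_spec[np.newaxis]

rng = np.random.default_rng(0)
stems, channels, freqs, frames = 3, 2, 9, 4
mix = rng.standard_normal((channels, freqs, frames)) + 1j * rng.standard_normal((channels, freqs, frames))
masks = rng.standard_normal((stems, channels, freqs, frames)) + 1j * rng.standard_normal((stems, channels, freqs, frames))
masks[-1] = 1 - masks[:-1].sum(axis=0)   # toy constraint: masks sum to one across stems
estimates = separate(mix, masks)
print(estimates.shape)  # (3, 2, 9, 4)
```

The toy masks are forced to sum to one across stems only so that the estimates add back up to the mixture and the example is easy to check; a trained model is not assumed to enforce this.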
Please join if you are interested in replicating a SOTA music source separator out in the open.
Update: This paper has been replicated by Roman, and the weights have been open-sourced here.
Update 2: Used for this Katy Perry remix!
Update 3: Kimberley Jensen has open-sourced a Mel-Band Roformer trained on vocals here!
- StabilityAI and 🤗 Huggingface for the generous sponsorship, as well as my other sponsors, for affording me the independence to open source artificial intelligence.
- Roee and Fabian-Robert for sharing their audio expertise and fixing audio hyperparameters
- @chenht2010 and Roman for working out the default band splitting hyperparameter!
- Max Prod for reporting a big bug with Mel-Band Roformer with stereo training!
- Roman for successfully training the model and open sourcing his training code and weights at this repository!
- Christopher for fixing an issue with multiple stems in Mel-Band Roformer
- Iver Jordal for identifying that the default STFT window function was not correct
```bash
$ pip install BS-RoFormer
```
```python
import torch
from bs_roformer import BSRoformer

model = BSRoformer(
    dim = 512,
    depth = 12,
    time_transformer_depth = 1,
    freq_transformer_depth = 1
)

x = torch.randn(2, 352800)
target = torch.randn(2, 352800)

loss = model(x, target = target)
loss.backward()

# after much training

out = model(x)
```
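Inside the model, a band-split front end turns the STFT into the (time, band) grid of tokens that the axial transformer attends over. A minimal sketch, assuming one linear projection per band applied to the stacked real and imaginary parts of that band's bins; the library's module also normalizes each band, which is omitted here, and `band_split` as well as the band widths below are hypothetical.

```python
import numpy as np

def band_split(spec, freqs_per_band, weights):
    """Split a complex spectrogram into bands and project each band to a shared dim.

    spec:           complex (freqs, frames) for one channel
    freqs_per_band: band widths in bins; must sum to freqs
    weights:        one real matrix per band, shape (2 * width, dim)
    returns:        real (frames, bands, dim)
    """
    assert sum(freqs_per_band) == spec.shape[0], "bands must cover every bin exactly once"
    tokens, start = [], 0
    for width, w in zip(freqs_per_band, weights):
        band = spec[start:start + width]                          # (width, frames)
        feats = np.concatenate([band.real, band.imag], axis=0)   # (2 * width, frames)
        tokens.append(feats.T @ w)                               # (frames, dim)
        start += width
    return np.stack(tokens, axis=1)

n_fft, frames, dim = 16, 5, 8
freqs = n_fft // 2 + 1                 # 9 one-sided STFT bins
widths = (1, 1, 1, 2, 4)               # hypothetical: narrow bands low, wider bands high
rng = np.random.default_rng(0)
spec = rng.standard_normal((freqs, frames)) + 1j * rng.standard_normal((freqs, frames))
weights = [rng.standard_normal((2 * w, dim)) for w in widths]
print(band_split(spec, widths, weights).shape)  # (5, 5, 8)
```

Giving low frequencies narrower bands keeps more resolution where pitch information is dense, while every band still maps to a token of the same size.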
To use the Mel-Band Roformer proposed in a recent follow-up paper, simply import `MelBandRoformer` instead.
```python
import torch
from bs_roformer import MelBandRoformer

model = MelBandRoformer(
    dim = 32,
    depth = 1,
    time_transformer_depth = 1,
    freq_transformer_depth = 1
)

x = torch.randn(2, 352800)
target = torch.randn(2, 352800)

loss = model(x, target = target)
loss.backward()

# after much training

out = model(x)
```
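Mel-Band Roformer replaces hand-picked band widths with bands spaced on the mel scale. The sketch below assumes the HTK mel formula and bands that span from one neighbour's centre to the next, so adjacent bands overlap; the function names and the choice of 60 bands are illustrative, and the library may derive its bands differently (for example from a mel filterbank).

```python
import numpy as np

def hz_to_mel(f):
    # HTK-style mel scale (an assumption; other mel variants exist)
    return 2595.0 * np.log10(1.0 + f / 700.0)

def mel_to_hz(m):
    return 700.0 * (10.0 ** (m / 2595.0) - 1.0)

def mel_band_bins(n_fft, sample_rate, num_bands):
    """Return, for each mel band, the indices of the one-sided STFT bins it covers."""
    freqs = np.arange(n_fft // 2 + 1) * sample_rate / n_fft        # bin centre frequencies in Hz
    mel_points = np.linspace(0.0, hz_to_mel(sample_rate / 2), num_bands + 2)
    hz_points = mel_to_hz(mel_points)
    hz_points[0], hz_points[-1] = 0.0, sample_rate / 2             # pin edges against rounding
    bands = []
    for i in range(num_bands):
        lo, hi = hz_points[i], hz_points[i + 2]                    # overlaps the next band
        bands.append(np.nonzero((freqs >= lo) & (freqs <= hi))[0])
    return bands

n_fft, sample_rate = 2048, 44100
bands = mel_band_bins(n_fft, sample_rate, num_bands=60)

# how many bands cover each bin; interior bins are covered more than once
counts = np.zeros(n_fft // 2 + 1)
for idx in bands:
    counts[idx] += 1
print(len(bands), int(counts.min()), int(counts.max()))
```

Because bins in an overlap belong to more than one band, the per-band mask estimates for a bin have to be combined, for example averaged by dividing by `counts`.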
- get the multiscale stft loss in there
- figure out what `n_fft` should be
- review band split + mask estimation modules
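The multiscale STFT loss mentioned in the list above compares prediction and target at several time-frequency resolutions, which also avoids depending on a single `n_fft`. A NumPy sketch, assuming an L1 distance on magnitude spectrograms with a periodic Hann window; the resolution list and hop size are illustrative rather than the library's settings, and the function names are hypothetical.

```python
import numpy as np

def stft_mag(x, n_fft, hop):
    """Magnitude STFT of a 1-D signal with a periodic Hann window (no padding)."""
    window = np.hanning(n_fft + 1)[:-1]                  # periodic Hann of length n_fft
    num_frames = 1 + (len(x) - n_fft) // hop
    frames = np.stack([x[i * hop:i * hop + n_fft] * window for i in range(num_frames)])
    return np.abs(np.fft.rfft(frames, axis=-1))          # (frames, n_fft // 2 + 1)

def multiscale_stft_loss(pred, target, n_ffts=(4096, 2048, 1024, 512, 256), hop=147):
    """Sum over resolutions of the mean L1 distance between magnitude spectrograms."""
    loss = 0.0
    for n_fft in n_ffts:
        loss += np.mean(np.abs(stft_mag(pred, n_fft, hop) - stft_mag(target, n_fft, hop)))
    return loss

rng = np.random.default_rng(0)
pred, target = rng.standard_normal(8000), rng.standard_normal(8000)
print(multiscale_stft_loss(pred, target))
```

Large windows penalize errors in fine frequency detail while small windows penalize smeared transients, so summing over several sizes constrains both.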
```bibtex
@inproceedings{Lu2023MusicSS,
    title   = {Music Source Separation with Band-Split RoPE Transformer},
    author  = {Wei-Tsung Lu and Ju-Chiang Wang and Qiuqiang Kong and Yun-Ning Hung},
    year    = {2023},
    url     = {https://api.semanticscholar.org/CorpusID:261556702}
}
```

```bibtex
@inproceedings{Wang2023MelBandRF,
    title   = {Mel-Band RoFormer for Music Source Separation},
    author  = {Ju-Chiang Wang and Wei-Tsung Lu and Minz Won},
    year    = {2023},
    url     = {https://api.semanticscholar.org/CorpusID:263608675}
}
```

```bibtex
@misc{ho2019axial,
    title   = {Axial Attention in Multidimensional Transformers},
    author  = {Jonathan Ho and Nal Kalchbrenner and Dirk Weissenborn and Tim Salimans},
    year    = {2019},
    archivePrefix = {arXiv}
}
```

```bibtex
@misc{su2021roformer,
    title   = {RoFormer: Enhanced Transformer with Rotary Position Embedding},
    author  = {Jianlin Su and Yu Lu and Shengfeng Pan and Bo Wen and Yunfeng Liu},
    year    = {2021},
    eprint  = {2104.09864},
    archivePrefix = {arXiv},
    primaryClass = {cs.CL}
}
```

```bibtex
@inproceedings{dao2022flashattention,
    title   = {Flash{A}ttention: Fast and Memory-Efficient Exact Attention with {IO}-Awareness},
    author  = {Dao, Tri and Fu, Daniel Y. and Ermon, Stefano and Rudra, Atri and R{\'e}, Christopher},
    booktitle = {Advances in Neural Information Processing Systems},
    year    = {2022}
}
```

```bibtex
@article{Bondarenko2023QuantizableTR,
    title   = {Quantizable Transformers: Removing Outliers by Helping Attention Heads Do Nothing},
    author  = {Yelysei Bondarenko and Markus Nagel and Tijmen Blankevoort},
    journal = {ArXiv},
    year    = {2023},
    volume  = {abs/2306.12929},
    url     = {https://api.semanticscholar.org/CorpusID:259224568}
}
```

```bibtex
@inproceedings{ElNouby2021XCiTCI,
    title   = {XCiT: Cross-Covariance Image Transformers},
    author  = {Alaaeldin El-Nouby and Hugo Touvron and Mathilde Caron and Piotr Bojanowski and Matthijs Douze and Armand Joulin and Ivan Laptev and Natalia Neverova and Gabriel Synnaeve and Jakob Verbeek and Herv{\'e} J{\'e}gou},
    booktitle = {Neural Information Processing Systems},
    year    = {2021},
    url     = {https://api.semanticscholar.org/CorpusID:235458262}
}
```