# Tree Attention Torch

An implementation of Tree Attention in PyTorch, because the original implementation is only available in JAX.

## Usage

Run `python3 model.py`.

## License

MIT

## Todo

- Integrate FlashAttention from its official repository. I haven't managed this yet because I couldn't find documentation for it that I could follow.
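## Sketch of the core idea

Tree Attention splits the keys and values across devices. Each device computes partial softmax statistics: the row max, the sum of exponentials, and an unnormalized output. It then merges these partials with an associative operation, which lets the merge run as a tree reduction in logarithmic depth instead of sequentially. The sketch below illustrates only that math. It runs in a single process, with chunks of K/V standing in for devices, whereas a multi-GPU version would exchange the partials through collective communication. The function names (`chunk_partials`, `combine`, `tree_reduce`, `tree_attention`) are hypothetical and are not taken from `model.py`.

```python
import torch


def chunk_partials(q, k, v):
    """Partial softmax statistics for queries q against one K/V chunk.

    q: (n_q, d), k: (n_k, d), v: (n_k, d_v)
    Returns (m, l, o): row max, sum of exponentials, unnormalized output.
    """
    scores = q @ k.T / q.shape[-1] ** 0.5        # (n_q, n_k) scaled dot products
    m = scores.amax(dim=-1, keepdim=True)         # (n_q, 1) row max
    p = torch.exp(scores - m)                     # shift by max for numerical stability
    l = p.sum(dim=-1, keepdim=True)               # (n_q, 1) softmax denominator
    o = p @ v                                     # (n_q, d_v), not yet normalized
    return m, l, o


def combine(a, b):
    """Associative merge of two partials; associativity is what allows a tree."""
    m_a, l_a, o_a = a
    m_b, l_b, o_b = b
    m = torch.maximum(m_a, m_b)                   # new running max
    scale_a = torch.exp(m_a - m)                  # rescale each side to the new max
    scale_b = torch.exp(m_b - m)
    return m, l_a * scale_a + l_b * scale_b, o_a * scale_a + o_b * scale_b


def tree_reduce(partials):
    """Pairwise reduction: about log2(n) rounds instead of n - 1 sequential merges."""
    while len(partials) > 1:
        nxt = [combine(partials[i], partials[i + 1])
               for i in range(0, len(partials) - 1, 2)]
        if len(partials) % 2 == 1:                # carry an odd leftover to the next round
            nxt.append(partials[-1])
        partials = nxt
    return partials[0]


def tree_attention(q, k, v, num_chunks):
    """Split K/V into chunks (stand-ins for devices), then tree-reduce the partials."""
    partials = [chunk_partials(q, kc, vc)
                for kc, vc in zip(k.chunk(num_chunks), v.chunk(num_chunks))]
    _, l, o = tree_reduce(partials)
    return o / l                                  # normalize once at the end


if __name__ == "__main__":
    torch.manual_seed(0)
    q, k, v = torch.randn(4, 16), torch.randn(64, 16), torch.randn(64, 8)
    ref = torch.softmax(q @ k.T / 16 ** 0.5, dim=-1) @ v   # plain full attention
    out = tree_attention(q, k, v, num_chunks=5)
    print(torch.allclose(out, ref, atol=1e-5))
```

The result should match ordinary softmax attention up to floating-point rounding, whatever the number of chunks. Only the order in which partial sums are added changes. Normalizing once at the end, rather than after each merge, keeps each merge to a few multiply-adds per row.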