LEDiT: Your Length-Extrapolatable Diffusion Transformer without Positional Encoding
Official PyTorch Implementation
This repo contains PyTorch model definitions, pre-trained weights, and training/sampling code for our paper exploring the Length-Extrapolatable Diffusion Transformer (LEDiT). You can find more visualizations on our project page.
Your Length-Extrapolatable Diffusion Transformer without Positional Encoding
Shen Zhang¹, Siyuan Liang¹, Yaning Tan², Zhaowei Chen¹, Linze Li¹, Ge Wu³, Yuhao Chen¹, Shuheng Li¹, Zhenyu Zhao¹, Caihua Chen², Jiajun Liang¹†, Yao Tang¹†
¹JIIOV Technology, ²Nanjing University, ³Nankai University
Diffusion transformers (DiTs) struggle to generate images at resolutions higher than those they were trained on. The primary obstacle is that explicit positional encodings (PEs), such as RoPE, must extrapolate to unseen positions, which degrades performance when the inference resolution differs from the training resolution.
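For intuition about why explicit PEs must extrapolate, the minimal sketch below inspects 1D RoPE frequencies. It assumes the common schedule θ_j = base^(-2j/d) with base 10000, a per-axis dimension of 64, and a DiT-style grid of 16 tokens per side at 256×256 growing to 32 at 512×512; these are illustrative assumptions, not this repo's settings, and the helper names are hypothetical. Pairs whose wavelength 2π/θ_j exceeds the training length never complete a rotation during training, so larger positions give them angles the model has never seen. The `ntk_aware_base` helper applies the NTK-aware rescaling base·s^(d/(d-2)), which keeps the highest frequency unchanged and divides the lowest frequency by s.

```python
import math


def rope_inv_freqs(dim, base=10000.0):
    """RoPE frequency of each rotated pair: theta_j = base ** (-2j / dim)."""
    return [base ** (-2 * j / dim) for j in range(dim // 2)]


def pairs_without_full_turn(dim, train_len, base=10000.0):
    """Indices of pairs whose wavelength 2*pi/theta_j is longer than train_len.
    Their angles never complete a full turn during training, so positions at or
    beyond train_len produce rotation angles that training never covered."""
    return [
        j
        for j, theta in enumerate(rope_inv_freqs(dim, base))
        if 2 * math.pi / theta > train_len
    ]


def ntk_aware_base(base, dim, scale):
    """NTK-aware rescaled base: base * scale ** (dim / (dim - 2)).
    The highest frequency (j = 0) is unchanged, the lowest is divided by
    `scale`, and intermediate pairs are scaled by intermediate amounts."""
    return base * scale ** (dim / (dim - 2))


if __name__ == "__main__":
    # Hypothetical per-axis sizes: 16 tokens/side in training, 32 at inference.
    dim, train_len, infer_len = 64, 16, 32
    unseen = pairs_without_full_turn(dim, train_len)
    print(f"{len(unseen)} of {dim // 2} pairs never complete a turn in training")

    scale = infer_len / train_len
    orig = rope_inv_freqs(dim)
    scaled = rope_inv_freqs(dim, ntk_aware_base(10000.0, dim, scale))
    print(f"highest freq: {orig[0]:.3f} -> {scaled[0]:.3f}")
    print(f"lowest-freq ratio orig/scaled: {orig[-1] / scaled[-1]:.3f}")
```

Under these assumed sizes, 28 of the 32 pairs fall into this regime. Methods such as NTK-aware scaling and YaRN patch it by reshaping the frequencies, whereas LEDiT sidesteps it by not using explicit PEs.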
In this paper, we propose the Length-Extrapolatable Diffusion Transformer (LEDiT), a simple yet powerful architecture that overcomes this limitation. LEDiT uses no explicit PEs and therefore avoids extrapolation altogether. Its key innovations are introducing causal attention, which implicitly imparts global positional information to tokens, and enhancing locality so that adjacent tokens can be precisely distinguished. Experiments on 256×256 and 512×512 ImageNet show that LEDiT can scale the inference resolution to 512×512 and 1024×1024, respectively, while achieving better image quality than current state-of-the-art length extrapolation methods (NTK-aware, YaRN). Moreover, LEDiT achieves strong extrapolation performance with just 100k steps of fine-tuning on a pretrained DiT, demonstrating its potential for integration into existing text-to-image DiTs.
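The idea that causal attention carries position without any PE can be checked with a toy model. This is a minimal sketch, not the LEDiT block: it assumes all query-key logits are zero, so each token attends uniformly to whatever the mask allows, and uses scalar i.i.d. Gaussian inputs; the function names are hypothetical. Under a causal mask, token i averages i+1 values, so its output variance falls roughly as 1/(i+1) and therefore reflects its index. Under bidirectional attention, every token receives the same global mean and positions are indistinguishable.

```python
import random
import statistics


def causal_uniform_attention(values):
    """Token i attends uniformly to tokens 0..i (zero logits + causal mask)."""
    out, running = [], 0.0
    for i, v in enumerate(values):
        running += v
        out.append(running / (i + 1))
    return out


def bidirectional_uniform_attention(values):
    """Every token attends uniformly to all tokens (zero logits, no mask)."""
    mean = sum(values) / len(values)
    return [mean] * len(values)


def output_variance_per_position(attn, seq_len=16, trials=4000, seed=0):
    """Empirical variance of each position's output over random scalar inputs.
    No positional encoding is added to the inputs at any point."""
    rng = random.Random(seed)
    per_pos = [[] for _ in range(seq_len)]
    for _ in range(trials):
        values = [rng.gauss(0.0, 1.0) for _ in range(seq_len)]
        for i, o in enumerate(attn(values)):
            per_pos[i].append(o)
    return [statistics.pvariance(xs) for xs in per_pos]


if __name__ == "__main__":
    causal = output_variance_per_position(causal_uniform_attention)
    bidir = output_variance_per_position(bidirectional_uniform_attention)
    for i in (0, 3, 15):
        print(f"pos {i:2d}: causal var {causal[i]:.3f} (expected ~1/{i + 1}), "
              f"bidirectional var {bidir[i]:.3f}")
```

Real attention uses learned, content-dependent logits, so this only illustrates the mechanism. How LEDiT pairs causal attention with its locality enhancement is described in the paper.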
