From: hu-po
A recent paper, “Scalable Diffusion Models with Transformers,” introduces a new class of diffusion models called Diffusion Transformers (DiTs) [00:14:12]. Published on December 19th, 2022, this research from UC Berkeley and New York University, with a significant compute budget provided by Meta AI’s FAIR team, showcases state-of-the-art image quality [00:03:09].
The project has a GitHub repository with code, PyTorch models, and pre-trained weights, although it operates under a CC BY-NC license, meaning it cannot be used for commercial purposes [00:22:31].
Key Innovation: Replacing U-Nets with Transformers
Diffusion models have historically adopted the U-Net architecture as their de facto backbone [00:07:34]. This paper demonstrates that the U-Net inductive bias is not crucial to the performance of diffusion models and can be effectively replaced with standard designs like Transformers [00:09:09].
The core idea is to train latent diffusion models where the commonly used U-Net backbone is swapped out for a Transformer that operates on latent patches [00:05:08]. This shift allows diffusion models to benefit from the recent trend of architecture unification seen across NLP and vision [00:09:21].
Diffusion Transformer (DiT) Architecture
DiTs adhere to the best practices of Vision Transformers (ViTs) [00:09:48]. Instead of convolving across an image like a traditional convolutional neural network (CNN), ViTs take an image, cut it into small patches, and feed each patch as a token into the Transformer [00:10:08].
The input to a DiT is a spatial latent representation [00:28:56]. For instance, a 256x256x3 image is first compressed by a pre-trained variational autoencoder (VAE) into a lower-dimensional latent, such as 32x32x4 [00:29:17]. This lets the Transformer operate on a much smaller input, improving efficiency [00:29:23].
DiTs use frequency-based positional embeddings, which provide spatial information for each patch [00:29:42]. The number of tokens T created from the patches is determined by a patch-size hyperparameter P (for a 32x32 latent, T = (32/P)^2), and the paper experiments with P values of 2, 4, and 8 [00:31:52]. Smaller patch sizes yield more tokens and higher computational cost, but also finer detail [00:56:01].
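As a concrete illustration, here is a minimal PyTorch sketch (not the authors' code) of patchifying a 32x32x4 latent into T = (32/P)^2 tokens and adding fixed frequency-based positional embeddings. The 1D sinusoidal embedding is a simplification of the 2D ViT-style version, and names like `patchify` are illustrative:

```python
import math
import torch
import torch.nn as nn

def patchify(latent: torch.Tensor, P: int, D: int) -> torch.Tensor:
    """Cut a (B, C, H, W) latent into (H/P)*(W/P) tokens of dimension D."""
    B, C, H, W = latent.shape
    # In a real model this projection is a learned layer; it is created
    # fresh here just to demonstrate the shapes.
    proj = nn.Conv2d(C, D, kernel_size=P, stride=P)
    tokens = proj(latent)                       # (B, D, H/P, W/P)
    return tokens.flatten(2).transpose(1, 2)    # (B, T, D), T = (H/P)*(W/P)

def sincos_pos_embed(T: int, D: int) -> torch.Tensor:
    """Fixed frequency-based (sinusoidal) positional embeddings, one per token."""
    pos = torch.arange(T, dtype=torch.float32).unsqueeze(1)            # (T, 1)
    freq = torch.exp(-math.log(10000.0) * torch.arange(0, D, 2) / D)   # (D/2,)
    emb = torch.zeros(T, D)
    emb[:, 0::2] = torch.sin(pos * freq)
    emb[:, 1::2] = torch.cos(pos * freq)
    return emb

latent = torch.randn(1, 4, 32, 32)   # stand-in for a VAE-encoded 32x32x4 latent
for P in (2, 4, 8):
    tokens = patchify(latent, P=P, D=1152)
    tokens = tokens + sincos_pos_embed(tokens.shape[1], 1152)
    print(P, tokens.shape)           # T = 256, 64, and 16 tokens respectively
```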
Conditioning and Normalization
DiTs are conditioned on:
- Noise time step: Information about the current step in the diffusion process [00:32:55].
- Class labels: E.g., ‘dog’ or ‘cat’ for ImageNet [00:32:52].
- Natural language prompts: Encoded text prompts, often using models like CLIP [00:33:03].
One block design introduces these conditioning elements as additional tokens, kept separate from the image tokens, with a cross-attention layer linking them to the main image pathway [00:33:41]. This cross-attention layer adds only about 15% compute overhead [00:34:01].
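A rough sketch of that cross-attention conditioning, assuming standard PyTorch multi-head attention (the shapes and token counts are illustrative, not from the released code):

```python
import torch
import torch.nn as nn

B, T, D = 1, 256, 1152            # batch, image tokens, hidden dimension
attn = nn.MultiheadAttention(embed_dim=D, num_heads=16, batch_first=True)

img_tokens = torch.randn(B, T, D)   # patchified latent tokens
cond_tokens = torch.randn(B, 2, D)  # e.g. timestep + class-label embeddings

# Queries come from the image tokens; keys and values come from the much
# shorter conditioning sequence, which keeps the added cost small.
out, _ = attn(query=img_tokens, key=cond_tokens, value=cond_tokens)
img_tokens = img_tokens + out       # residual connection back into the block
```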
The models also heavily utilize adaptive normalization layers, such as adaptive layer normalization (AdaLN), which incorporate scale and shift parameters [00:34:09]. These normalization techniques ensure activations remain well-distributed, aiding gradient flow and learning [00:35:15]. They are applied within Transformer blocks and prior to any residual connections [00:37:05].
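A minimal sketch of adaptive layer normalization along these lines (illustrative, not the paper's exact adaLN-Zero block): a linear layer regresses per-block scale and shift from the conditioning embedding, and the layer norm's own affine parameters are disabled:

```python
import torch
import torch.nn as nn

class AdaLN(nn.Module):
    """LayerNorm whose scale and shift are regressed from a conditioning vector."""
    def __init__(self, dim: int, cond_dim: int):
        super().__init__()
        # Plain LayerNorm with its own affine parameters disabled.
        self.norm = nn.LayerNorm(dim, elementwise_affine=False)
        self.to_scale_shift = nn.Linear(cond_dim, 2 * dim)
        # Zero init: the modulation starts out as plain LayerNorm. The paper's
        # adaLN-Zero variant also zero-initializes a gate on each residual
        # branch so every block starts as the identity.
        nn.init.zeros_(self.to_scale_shift.weight)
        nn.init.zeros_(self.to_scale_shift.bias)

    def forward(self, x: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
        scale, shift = self.to_scale_shift(cond).chunk(2, dim=-1)
        return self.norm(x) * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)

# cond could be, e.g., the sum of timestep and class-label embeddings.
x, cond = torch.randn(1, 256, 1152), torch.randn(1, 1152)
out = AdaLN(1152, 1152)(x, cond)   # applied before the attention/MLP sublayers
```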
Performance and Scalability
The paper analyzes the scalability of DiTs through their forward pass complexity, measured in Gigaflops [00:05:26]. Key findings include:
- Increased Depth, Width, and Tokens: DiTs with higher Gigaflops (achieved through increased Transformer depth, width, or number of input tokens) consistently achieve lower FID (Fréchet Inception Distance) scores, indicating better image quality [00:05:42]. This confirms the well-known machine learning principle that bigger models and more compute lead to better results [00:06:03].
- State-of-the-Art Results: The largest model, DiT-XL/2 (DiT-XL with a patch size of 2), achieved a state-of-the-art FID of 2.27 on the 256x256 class-conditional ImageNet generation benchmark [00:11:32].
- Compute Efficiency: DiT-XL/2 outperforms all prior U-Net based diffusion models (e.g., LDM-4, LDM-4-G, ADM-G), even those with higher Gigaflops, demonstrating superior compute efficiency [00:12:45]. This suggests that while scaling is powerful, a better architecture can yield superior results even at a smaller scale [00:13:38].
Different DiT model sizes (Small, Base, Large, XL) are defined by their number of Transformer layers (N), hidden dimension size (D), and number of attention heads. For example, DiT-S has 12 layers, 6 heads, and a hidden dimension of 384, while DiT-XL has 28 layers, 16 heads, and a hidden dimension of 1152 [00:42:01].
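For reference, the two configurations quoted above, written as a small config map (the DiT-B and DiT-L values are not quoted here, so they are omitted):

```python
# Transformer layers (N), hidden dimension (D), and attention heads per model.
DIT_CONFIGS = {
    "DiT-S":  {"layers": 12, "hidden_dim": 384,  "heads": 6},
    "DiT-XL": {"layers": 28, "hidden_dim": 1152, "heads": 16},
}
```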
Training and Implementation Details
Loss Function and Guidance
The models are trained with a simple mean squared error between the predicted noise and the ground-truth sampled Gaussian noise [00:24:09]. They also leverage classifier-free guidance: conditioning labels (such as class labels) are randomly dropped out during training so that, at sampling time, conditional and unconditional predictions can be combined to steer generation [00:27:01].
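A hedged sketch of this objective and of classifier-free guidance at sampling time, using standard DDPM-style formulas (the 10% label-drop rate, the `null_label` placeholder, and the guidance scale are illustrative assumptions, not values from the paper):

```python
import torch
import torch.nn.functional as F

def training_loss(model, x0, y, alphas_cumprod, null_label, p_drop=0.1):
    """MSE between predicted and ground-truth noise, with random label
    dropout (10% here, an assumed rate) to enable classifier-free guidance.
    null_label is a reserved "no class" index, e.g. torch.tensor(1000)."""
    B = x0.shape[0]
    t = torch.randint(0, len(alphas_cumprod), (B,))          # random timesteps
    eps = torch.randn_like(x0)                               # ground-truth noise
    abar = alphas_cumprod[t].view(B, 1, 1, 1)
    x_t = abar.sqrt() * x0 + (1 - abar).sqrt() * eps         # forward diffusion
    y = torch.where(torch.rand(B) < p_drop, null_label, y)   # drop some labels
    return F.mse_loss(model(x_t, t, y), eps)

def guided_eps(model, x_t, t, y, null_label, w=4.0):
    """Classifier-free guidance at sampling: push the conditional noise
    prediction away from the unconditional one by a guidance scale w."""
    eps_cond = model(x_t, t, y)
    eps_uncond = model(x_t, t, null_label.expand_as(y))
    return eps_uncond + w * (eps_cond - eps_uncond)
```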
Hyperparameters and Initialization
DiTs are trained with a constant learning rate of 1e-4 and no weight decay, while an exponential moving average (EMA) of the DiT weights is maintained throughout training [00:45:05]. Many hyperparameters are retained from prior work (ADM) and were not extensively tuned, suggesting potential for further optimization [00:48:18]. The final layer is initialized with zeros, and standard ViT weight-initialization techniques are used otherwise [00:44:55].
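In PyTorch terms, that setup looks roughly like the following (AdamW and the 0.9999 EMA decay are assumptions based on common practice with these models; the tiny `nn.Sequential` stands in for a real DiT):

```python
import copy
import torch
import torch.nn as nn

# Tiny stand-in for a DiT; only the training mechanics matter here.
model = nn.Sequential(nn.Linear(8, 8), nn.GELU(), nn.Linear(8, 8))
nn.init.zeros_(model[-1].weight)   # zero-init the final layer, as noted above
nn.init.zeros_(model[-1].bias)

opt = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0.0)
ema = copy.deepcopy(model)         # EMA copy of the weights, used for sampling

@torch.no_grad()
def update_ema(decay: float = 0.9999):
    """Blend the live weights into the EMA copy after each optimizer step."""
    for p_ema, p in zip(ema.parameters(), model.parameters()):
        p_ema.mul_(decay).add_(p, alpha=1 - decay)
```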
Hardware and Cost
Models were implemented in JAX and trained on Google TPU v3 pods [00:51:19]. The largest model, DiT-XL/2, achieved 5.7 iterations per second on a TPU v3-256 pod [00:52:46].
To put this into perspective:
- Training for 400,000 steps at 5.7 iterations/second takes approximately 19 hours [00:54:01].
- Estimating TPU v3 pod pricing puts the cost at roughly $10,000 per model [00:54:46].
- Training 12 different DiT models for the reported experiments would therefore cost over $100,000, as the arithmetic sketch below shows [00:55:20].
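The back-of-the-envelope arithmetic behind those bullets, as a quick sketch (the $10,000-per-model figure is the video's rough estimate):

```python
steps, its_per_sec = 400_000, 5.7
hours = steps / its_per_sec / 3600
print(f"{hours:.1f} hours per 400K-step run")         # ~19.5 hours

cost_per_model, num_models = 10_000, 12
print(f"~${cost_per_model * num_models:,} in total")  # $120,000 for 12 models
```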
Sampling Steps
While the paper reports results primarily with 250 sampling steps, experiments show that increasing sampling steps generally improves quality, though with diminishing returns beyond a certain point [01:03:01].
Generated Image Quality and Future Outlook
The DiTs produce images with exceptional crispness and semantic correctness, capturing subtle details like the wrist of a dog or wet fur on an otter [00:03:52] [01:12:14]. However, like many generative models, they can still exhibit inconsistencies at a higher semantic level, such as illogical connections in complex objects (e.g., a boat’s mast appearing disconnected) [01:10:40].
The researchers note that additional model compute is a critical ingredient for improved DiT models, and larger DiT models are more compute-efficient [01:00:22]. The paper concludes by suggesting future work should continue to scale DiTs to larger models and token counts [01:01:10].
The rapid progress in diffusion models, compared to earlier generative models like GANs, suggests that the speed of improvement is accelerating [01:14:21]. This could lead to capabilities like generating 4K video from text prompts by the end of 2023, potentially transforming content creation [01:15:02].