
PyTorch Implementation of "Scalable Diffusion Models with Transformers" with author's tweaks


neonsecret/DiTFusion

Scalable Diffusion Models with Transformers (DiT)
Official PyTorch Implementation

Paper | Project Page | Run DiT-XL/2 on Hugging Face Spaces | Open in Colab

(Image: DiT samples)

This repo contains PyTorch model definitions, pre-trained weights and training/sampling code for our paper exploring diffusion models with transformers (DiTs). You can find more visualizations on our project page.

Scalable Diffusion Models with Transformers
William Peebles, Saining Xie
UC Berkeley, New York University

We train latent diffusion models, replacing the commonly used U-Net backbone with a transformer that operates on latent patches. We analyze the scalability of our Diffusion Transformers (DiTs) through the lens of forward-pass complexity as measured by Gflops. We find that DiTs with higher Gflops (through increased transformer depth/width or an increased number of input tokens) consistently have lower FID. In addition to good scalability properties, our DiT-XL/2 models outperform all prior diffusion models on the class-conditional ImageNet 512×512 and 256×256 benchmarks, achieving a state-of-the-art FID of 2.27 on the latter.
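
The number of input tokens mentioned above follows from simple arithmetic. The sketch below is a minimal illustration. It assumes the 8x-downsampling VAE used in the DiT paper, so a 256x256 image becomes a 32x32 latent, and square patches of side p. `num_tokens` and `patchify` are hypothetical helpers, not code from this repository. The learned linear projection of each patch into the transformer width is omitted.

```python
def num_tokens(image_size: int, patch_size: int, vae_factor: int = 8) -> int:
    """Number of latent patches (tokens) for a square image.

    Assumes a VAE that downsamples by `vae_factor` (8 in the DiT paper).
    """
    if image_size % vae_factor:
        raise ValueError("image_size must be divisible by vae_factor")
    latent_side = image_size // vae_factor
    if latent_side % patch_size:
        raise ValueError("latent side must be divisible by patch_size")
    return (latent_side // patch_size) ** 2


def patchify(latent, patch_size: int):
    """Split a C x H x W latent (nested lists) into non-overlapping p x p patches.

    Returns one flat vector of length p * p * C per patch, in row-major patch
    order. A real model would follow this with a learned linear projection.
    """
    channels = len(latent)
    height, width = len(latent[0]), len(latent[0][0])
    if height % patch_size or width % patch_size:
        raise ValueError("latent height and width must be divisible by patch_size")
    tokens = []
    for top in range(0, height, patch_size):
        for left in range(0, width, patch_size):
            tokens.append([
                latent[c][top + i][left + j]
                for i in range(patch_size)
                for j in range(patch_size)
                for c in range(channels)
            ])
    return tokens


# 256x256 image -> 32x32 latent -> 16x16 grid of 2x2 patches = 256 tokens.
print(num_tokens(256, 2))  # 256
print(num_tokens(512, 2))  # 1024
```

Halving p quadruples the token count, and therefore at least quadruples the transformer's Gflops, while leaving the parameter count essentially unchanged.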

This repository contains:

An implementation of DiT directly in Hugging Face diffusers can also be found here.

Setup

First, download and set up the repo:

git clone https://github.com/neonsecret/DiTFusion.git
cd DiTFusion

We provide an environment.yml file that can be used to create a Conda environment. If you only want to run pre-trained models locally on CPU, you can remove the cudatoolkit and pytorch-cuda requirements from the file.
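
If you prefer not to edit environment.yml by hand, a small script can drop the two GPU-only entries. This is a minimal sketch. It assumes each dependency sits on its own `- name...` line (for example `- pytorch-cuda=11.7`) without a channel prefix such as `nvidia::`. The function name is hypothetical, and the sample lines in the comments are illustrative rather than copied from this repository's environment.yml.

```python
import re

# GPU-only packages named in this README; everything else is kept.
GPU_ONLY_PACKAGES = {"cudatoolkit", "pytorch-cuda"}

# Matches a YAML list entry such as "  - pytorch-cuda=11.7" and captures the
# bare package name ("pytorch-cuda"). Assumes no "channel::" prefix.
_ENTRY = re.compile(r"^\s*-\s*([A-Za-z0-9_.\-]+)")


def strip_gpu_requirements(text: str) -> str:
    """Return environment.yml text with the GPU-only dependency lines removed."""
    kept = []
    for line in text.splitlines(keepends=True):
        match = _ENTRY.match(line)
        if match and match.group(1) in GPU_ONLY_PACKAGES:
            continue  # skip cudatoolkit / pytorch-cuda entries
        kept.append(line)
    return "".join(kept)
```

Write the returned text to a new file (for example environment-cpu.yml) and pass that file to `conda env create -f` instead.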

conda env create -f environment.yml
conda activate DiT

Sampling

(Image: More DiT samples)

Pre-trained DiT checkpoints. You can sample from our pre-trained DiT models with sample.py. Weights for our pre-trained DiT models will be downloaded automatically depending on the model you use. The script has various arguments to switch between the 256x256 and 512x512 models, adjust the number of sampling steps, change the classifier-free guidance scale, etc. For example, to sample from our 256x256 DiT-clipped model, you can use the new Gradio interface:

python sample_gradio.py --ckpt pretrained_models/last.ckpt
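
The two sampling knobs mentioned above, the number of sampling steps and the guidance scale, can be illustrated independently of the model. The sketch assumes the standard classifier-free guidance rule, eps = eps_uncond + s * (eps_cond - eps_uncond). It also assumes an evenly strided subset of the training timesteps, in the spirit of the timestep respacing in OpenAI's ADM code that this codebase borrows from. The helper names are hypothetical, and this README does not state whether sample.py or sample_gradio.py uses exactly these rules. Plain Python lists stand in for noise tensors.

```python
def guided_noise(eps_cond, eps_uncond, scale: float):
    """Classifier-free guidance, applied elementwise:
    eps_uncond + scale * (eps_cond - eps_uncond).
    """
    if len(eps_cond) != len(eps_uncond):
        raise ValueError("predictions must have the same length")
    return [u + scale * (c - u) for c, u in zip(eps_cond, eps_uncond)]


def evenly_spaced_timesteps(num_train_steps: int, num_sampling_steps: int):
    """Choose `num_sampling_steps` timesteps spread evenly over
    [0, num_train_steps - 1]; both endpoints are kept when more than
    one step is requested.
    """
    if not 1 <= num_sampling_steps <= num_train_steps:
        raise ValueError("need 1 <= num_sampling_steps <= num_train_steps")
    if num_sampling_steps == 1:
        return [0]
    stride = (num_train_steps - 1) / (num_sampling_steps - 1)
    return [round(i * stride) for i in range(num_sampling_steps)]


print(guided_noise([1.0, 2.0], [0.0, 1.0], 4.0))  # [4.0, 5.0]
print(evenly_spaced_timesteps(10, 4))               # [0, 3, 6, 9]
```

A scale of 1.0 reproduces the conditional prediction. Larger scales push samples further toward the condition. Fewer sampling steps trade quality for speed.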

For convenience, our pre-trained DiT models can be downloaded directly here as well:

DiT Model      Image Resolution
DiT_clipped    256x256

Training DiT

We provide a training script for DiT in train_pl.py. This script can be used to train class-conditional DiT models, but it can be easily modified to support other types of conditioning. To launch DiT-clipped (256x256) training with N GPUs on one node:

python train_pl.py --coco_dataset_path (...)/datasets/fast-ai-coco
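
Training optimizes the usual denoising objective. Below is a minimal sketch of one loss sample. It assumes a DDPM-style linear beta schedule from 1e-4 to 0.02 over 1000 steps, as in the DiT paper, and the simple epsilon-prediction MSE. The DiT paper additionally learns the noise covariance, and this README does not state which schedule or loss train_pl.py uses. All names are hypothetical, and plain lists stand in for latent tensors.

```python
import math
import random


def linear_alpha_bars(num_steps=1000, beta_start=1e-4, beta_end=0.02):
    """Cumulative products alpha_bar_t of (1 - beta_t) for a linear beta schedule."""
    alpha_bars, running = [], 1.0
    for t in range(num_steps):
        beta = beta_start + (beta_end - beta_start) * t / (num_steps - 1)
        running *= 1.0 - beta
        alpha_bars.append(running)
    return alpha_bars


def q_sample(x0, t, alpha_bars, noise):
    """Forward noising: x_t = sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * noise."""
    a = alpha_bars[t]
    return [math.sqrt(a) * x + math.sqrt(1.0 - a) * n for x, n in zip(x0, noise)]


def training_loss(predict_noise, x0, alpha_bars, rng=random):
    """One Monte Carlo sample of the epsilon-prediction MSE loss.

    `predict_noise(x_t, t)` stands in for the DiT forward pass.
    """
    t = rng.randrange(len(alpha_bars))          # uniform random timestep
    noise = [rng.gauss(0.0, 1.0) for _ in x0]   # standard Gaussian noise
    x_t = q_sample(x0, t, alpha_bars, noise)
    pred = predict_noise(x_t, t)
    return sum((p - n) ** 2 for p, n in zip(pred, noise)) / len(x0)


alpha_bars = linear_alpha_bars()
print(len(alpha_bars))  # 1000
```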

Enhancements

Possible improvements to the project include:

  • Improving generation quality by training the checkpoint further
  • Adding more DiT_clipped architectures with more parameters and training them more thoroughly

BibTeX

@article{Peebles2022DiT,
  title={Scalable Diffusion Models with Transformers},
  author={William Peebles and Saining Xie},
  year={2022},
  journal={arXiv preprint arXiv:2212.09748},
}

Acknowledgments

We thank Kaiming He, Ronghang Hu, Alexander Berg, Shoubhik Debnath, Tim Brooks, Ilija Radosavovic and Tete Xiao for helpful discussions. William Peebles is supported by the NSF Graduate Research Fellowship.

This codebase borrows from OpenAI's diffusion repos, most notably ADM.

License

The code and model weights are licensed under CC-BY-NC. See LICENSE.txt for details.
