Sicheng Mo¹, Thao Nguyen², Richard Zhang³, Nicholas Kolkin³, Siddharth Srinivasan Iyer³, Eli Shechtman³, Krishna Kumar Singh³, Yong Jae Lee², Bolei Zhou¹, Yuheng Li³
¹UCLA  ²UW-Madison  ³Adobe
This is the official implementation of Group Diffusion, a Generative AI algorithm for enhancing image generation via cross-sample attention.
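To make "cross-sample attention" concrete, here is a minimal NumPy sketch of one way it can be realized. The token sequences of the images in a group are concatenated so that each token can attend to tokens from every image in the group, and the result is then split back per image. This is an illustrative assumption, not the GroupDiff implementation: `group_attention` is a hypothetical helper, it uses a single head with identity Q/K/V projections, and it assumes a fixed group size that divides the batch. The `-4` in the released model names and `--num_max_sample 4` suggest groups of up to four images, but that reading is also an assumption.

```python
import numpy as np


def softmax(x, axis=-1):
    # Numerically stable softmax along `axis`.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def group_attention(tokens, group_size):
    """Hypothetical single-head attention shared across a group of images.

    tokens: array of shape (B, L, D); B must be divisible by group_size.
    Uses identity Q/K/V projections to keep the sketch short.
    """
    b, l, d = tokens.shape
    assert b % group_size == 0, "batch must split evenly into groups"
    # Merge each group of images into one long sequence: (B/G, G*L, D).
    grouped = tokens.reshape(b // group_size, group_size * l, d)
    q = k = v = grouped
    # Every token scores against all tokens of all images in its group.
    scores = q @ k.transpose(0, 2, 1) / np.sqrt(d)  # (B/G, G*L, G*L)
    out = softmax(scores) @ v                       # (B/G, G*L, D)
    # Split back into per-image sequences: (B, L, D).
    return out.reshape(b, l, d)
```

With `group_size=1` the function reduces to ordinary per-image self-attention, which is a convenient sanity check when comparing against a single-sample baseline.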
Download the code:
git clone https://github.com/adobe-research/GroupDiff.git
cd GroupDiff
Create and activate a conda environment:
conda create -n gdiff python=3.10 -y && conda activate gdiff
pip install -r requirements.txt
pip install 'tensorflow[and-cuda]'
Download the ImageNet dataset and place it under data/imagenet/.
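Before extracting features, it can help to sanity-check the dataset layout. The sketch below assumes the common ImageFolder layout (one subdirectory per class, with the image files inside it); `summarize_imagefolder` is a hypothetical helper, not part of this repository.

```python
from pathlib import Path

# Extensions counted as images (compared case-insensitively).
IMAGE_EXTS = {".jpeg", ".jpg", ".png"}


def summarize_imagefolder(root):
    """Return (num_class_dirs, num_images) for a root/<class>/<image> layout.

    Files whose suffix is not in IMAGE_EXTS are ignored.
    """
    root = Path(root)
    class_dirs = sorted(p for p in root.iterdir() if p.is_dir())
    num_images = sum(
        1
        for d in class_dirs
        for f in d.iterdir()
        if f.is_file() and f.suffix.lower() in IMAGE_EXTS
    )
    return len(class_dirs), num_images
```

For example, running it on the training directory you pass as `--data-path` should report 1,000 classes and 1,281,167 images for the full ILSVRC-2012 training split.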
After that, run the following script to extract the VAE latents and feature embeddings for each image.
mkdir data
# Download / link imagenet train to data/imagenet
ln -s YOUR_IMAGENET_PATH data/imagenet
torchrun --nnodes=1 --nproc_per_node=8 dataset/extract_feats.py \
--data-path data/imagenet/imagenet/train \
--features-path data/imagenet_feats/train \
--models all
Then, create a FAISS index for fast similarity search during training (this takes around 30 minutes):
python dataset/create_fasiss.py \
--metadata data/imagenet_feats/train/metadata.json \
--feature_key dinov2-l \
--output_dir data/imagenet_feats/train_index/dinov2-l \
--index_type ivfpq \
--num_workers 32
This builds a FAISS index over the extracted DINOv2-L features, enabling efficient nearest-neighbor search during GroupDiff training. The ivfpq index type (inverted file with product quantization) gives up a little search accuracy in exchange for much faster search and a much smaller memory footprint than exact search.
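For intuition about the `ivfpq` trade-off, the NumPy toy below implements only the IVF (inverted file) half. A k-means coarse quantizer splits the database into `nlist` lists, and a query scans only the `nprobe` lists whose centroids are nearest. It is a sketch, not FAISS's implementation, and it omits product quantization (PQ), which in a real IVFPQ index additionally compresses each stored vector into a few bytes of codes. `build_ivf` and `search_ivf` are hypothetical names.

```python
import numpy as np


def build_ivf(xb, nlist, iters=10, seed=0):
    """Toy IVF index: k-means coarse quantizer plus one id list per centroid."""
    rng = np.random.default_rng(seed)
    centroids = xb[rng.choice(len(xb), nlist, replace=False)].copy()
    for _ in range(iters):
        # Assign each vector to its nearest centroid (squared L2).
        assign = np.argmin(((xb[:, None, :] - centroids[None]) ** 2).sum(-1), axis=1)
        for c in range(nlist):
            members = xb[assign == c]
            if len(members):
                centroids[c] = members.mean(0)
    assign = np.argmin(((xb[:, None, :] - centroids[None]) ** 2).sum(-1), axis=1)
    # Inverted lists: database ids grouped by their assigned centroid.
    lists = [np.flatnonzero(assign == c) for c in range(nlist)]
    return centroids, lists


def search_ivf(xq, xb, centroids, lists, k, nprobe):
    """Exact L2 search restricted to the nprobe lists nearest to query xq."""
    coarse = np.argsort(((centroids - xq) ** 2).sum(-1))[:nprobe]
    cand = np.concatenate([lists[c] for c in coarse])
    dists = ((xb[cand] - xq) ** 2).sum(-1)
    order = np.argsort(dists)[:k]
    return cand[order], dists[order]
```

Raising `nprobe` toward `nlist` recovers exact search at higher cost. The PQ half matters for memory: assuming one 1024-dimensional float32 DINOv2-L vector per image, raw storage for the 1,281,167 training images is about 5.2 GB, whereas a hypothetical 64-byte PQ code per vector would need roughly 82 MB.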
| Model | Resume From | Pre-Training Iters | Fine-tuning Iters | Link |
|---|---|---|---|---|
| GroupDiff-l-4 (SiT-XL) | REPA-SiT | 4M | 500K | ckpt |
| GroupDiff-l-4 (DiT-XL) | REPA-SiT | 4M | 500K | ckpt |
Run the following commands to download our pre-trained checkpoints and the FID reference statistics file (from ADM's TensorFlow evaluation suite).
huggingface-cli download Sichengmo/GroupDiff --local-dir released_model
wget https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/256/VIRTUAL_imagenet256_labeled.npz -O data/VIRTUAL_imagenet256_labeled.npz
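The FID reported by the evaluation scripts compares Gaussian fits (mean and covariance) of Inception features from generated and reference images; the downloaded `.npz` supplies the reference side. As a sketch of the standard Fréchet distance formula (not the ADM suite's code), the function below computes it from two sets of statistics using only NumPy eigendecompositions; `frechet_distance` is a hypothetical helper name.

```python
import numpy as np


def frechet_distance(mu1, sigma1, mu2, sigma2):
    """Frechet distance between N(mu1, sigma1) and N(mu2, sigma2):
    ||mu1 - mu2||^2 + Tr(sigma1 + sigma2 - 2 (sigma1 sigma2)^{1/2}).
    """
    mu1, mu2 = np.asarray(mu1, float), np.asarray(mu2, float)
    sigma1, sigma2 = np.asarray(sigma1, float), np.asarray(sigma2, float)
    # Symmetric PSD square root of sigma1 via eigendecomposition.
    w, v = np.linalg.eigh(sigma1)
    sqrt_s1 = (v * np.sqrt(np.clip(w, 0, None))) @ v.T
    # Tr((sigma1 sigma2)^{1/2}) is the sum of square roots of the eigenvalues
    # of the symmetric matrix sqrt(sigma1) @ sigma2 @ sqrt(sigma1).
    m = sqrt_s1 @ sigma2 @ sqrt_s1
    eig = np.linalg.eigvalsh((m + m.T) / 2)
    tr_covmean = np.sqrt(np.clip(eig, 0, None)).sum()
    diff = mu1 - mu2
    return float(diff @ diff + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean)
```

Identical statistics give 0, and with identity covariances the distance reduces to the squared Euclidean distance between the means.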
Detailed scripts for GroupDiff training can be found in scripts/.
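The training commands below use a per-GPU batch size of 32 on 8 GPUs, i.e., a global batch of 256. The sketch converts that and the epoch count into optimizer steps, assuming each batch element is one ImageNet training image (1,281,167 in ILSVRC-2012) and that the final partial batch of each epoch is dropped. Neither assumption is taken from `train.py`, and `training_schedule` is a hypothetical helper.

```python
import math

IMAGENET_TRAIN_IMAGES = 1_281_167  # ILSVRC-2012 training split size


def training_schedule(per_gpu_batch, num_gpus, epochs,
                      dataset_size=IMAGENET_TRAIN_IMAGES, drop_last=True):
    """Return (global_batch, steps_per_epoch, total_steps)."""
    global_batch = per_gpu_batch * num_gpus
    if drop_last:
        # Incomplete final batch is discarded.
        steps_per_epoch = dataset_size // global_batch
    else:
        # Incomplete final batch still counts as a step.
        steps_per_epoch = math.ceil(dataset_size / global_batch)
    return global_batch, steps_per_epoch, steps_per_epoch * epochs
```

`training_schedule(32, 8, 800)` gives a global batch of 256, 5,004 steps per epoch, and 4,003,200 steps in total; keeping the partial batch (`drop_last=False`) gives 5,005 steps per epoch and 4,004,000 in total.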
Train GroupDiff-l with DiT-XL (800 epochs):
project=GroupDiff
exp_name=groupdiff-l-4-dit-xl-dinov2
batch_size=32 # per GPU batch size, global batch size = batch_size x num_gpus = 32 x 8 = 256
epochs=800
YOUR_WANDB_ENTITY="YourWandbEntity"
accelerate launch --num_processes 8 --multi_gpu --mixed_precision=bf16 train.py \
--project $project --exp_name $exp_name --auto_resume \
--model DiT_xl --patch_size 2 --num_max_sample 4 \
--batch_size $batch_size --epochs $epochs \
--lr 1e-4 --num_sampling_steps 250 \
--data_path data/imagenet/train \
--query_sim feat --use_cached_tokens \
--entity $YOUR_WANDB_ENTITY
Optional: Using Custom Feature Paths
To use custom feature and index locations, append these arguments to the training command:
--metadata_path data/imagenet_feats/train/metadata.json \
--faiss_index_path data/imagenet_feats/train_index/dinov2-l/imagenet_index_ivfpq.faiss \
--faiss_feature_name "dinov2-l" \
--features_root data/imagenet_feats/train \
--load_latent \
--latent_feature_name "vae-256" \
--latent_root data/imagenet_feats/train
Train GroupDiff-l with SiT-XL (800 epochs):
project=GroupDiff
exp_name=groupdiff-l-4-sit-xl-dinov2
batch_size=32 # per GPU batch size, global batch size = batch_size x num_gpus = 32 x 8 = 256
epochs=800
YOUR_WANDB_ENTITY="YourWandbEntity"
accelerate launch --num_processes 8 --multi_gpu --mixed_precision=bf16 train.py \
--project $project --exp_name $exp_name --auto_resume \
--model SiT_xl --patch_size 2 --num_max_sample 4 \
--batch_size $batch_size --epochs $epochs \
--lr 1e-4 --num_sampling_steps 250 \
--data_path data/imagenet/train \
--query_sim feat --use_cached_tokens \
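--entity $YOUR_WANDB_ENTITY
To evaluate a model you trained yourself, run the following (set --model to the architecture you trained; the example uses DiT_base):
checkpoint_path=work_dirs/${project}/${exp_name}/checkpoints/lastest.pth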
accelerate launch --num_processes 8 --multi_gpu --mixed_precision=bf16 eval.py \
--project $project --exp_name $exp_name --auto_resume \
--model DiT_base --patch_size 2 --num_max_sample 4 \
--batch_size $batch_size --eval_bsz 128 \
--num_sampling_steps 250 --cfg 2.5 \
--guidance_low 0.0 --guidance_high 1.0 \
--cond_group_size 1 --uncond_group_size 4 \
--num_images 50000 --seed 0 \
--load_from ${checkpoint_path} --use_ema \
--fid_stats_path data/VIRTUAL_imagenet256_labeled.npz \
--entity $YOUR_WANDB_ENTITY
To evaluate the released DiT-XL checkpoint:
project=GroupDiff-l-pretrained
exp_name=gdiff-l-4-dit-xl-2-resume
checkpoint_path=released_model/${exp_name}.pth
batch_size=32 # per GPU batch size
YOUR_WANDB_ENTITY="YourWandbEntity"
accelerate launch --num_processes 8 --multi_gpu --mixed_precision=bf16 eval.py \
--project $project --exp_name $exp_name --auto_resume \
--model DiT_xl --patch_size 2 --num_max_sample 4 \
--batch_size $batch_size --eval_bsz 128 \
--num_sampling_steps 250 --cfg 1.65 \
--guidance_low 0.0 --guidance_high 1.0 \
--cond_group_size 1 --uncond_group_size 4 \
--num_images 50000 --seed 0 \
--load_from ${checkpoint_path} --use_ema \
--fid_stats_path data/VIRTUAL_imagenet256_labeled.npz \
--entity $YOUR_WANDB_ENTITY
To evaluate the released SiT-XL checkpoint:
project=GroupDiff-l-pretrained
exp_name=gdiff-l-4-sit-xl-2-repa-resume
checkpoint_path=released_model/${exp_name}.pth
batch_size=32 # per GPU batch size
YOUR_WANDB_ENTITY="YourWandbEntity"
accelerate launch --num_processes 8 --multi_gpu --mixed_precision=bf16 eval.py \
--project $project --exp_name $exp_name --auto_resume \
--model SiT_xl --patch_size 2 --num_max_sample 4 \
--batch_size $batch_size --eval_bsz 128 \
--num_sampling_steps 250 --cfg 2.585 \
--guidance_low 0.25 --guidance_high 0.75 \
--cond_group_size 1 --uncond_group_size 4 \
--num_images 50000 --seed 0 \
--load_from ${checkpoint_path} --use_ema \
--fid_stats_path data/VIRTUAL_imagenet256_labeled.npz \
--entity $YOUR_WANDB_ENTITY
We thank the authors of DiT and SiT for their foundational work.
Our codebase builds upon several excellent open-source projects, including DeTok and MAR. We are grateful to the communities behind them.
This codebase has been cleaned up but has not undergone extensive testing. If you encounter any issues or have questions, please open a GitHub issue. We appreciate your feedback!
