This is my take on the official PyTorch FSDP2 example found here:
https://github.com/pytorch/examples/tree/main/distributed/FSDP2
The original example has some lint I wanted to clean up. This project will serve as my template for creating other trainers that take advantage of FSDP2.
I use AMD GPUs. The requirements.txt installs the ROCm-specific PyTorch build from AMD's wheels:
- PyTorch 2.6
- ROCm 6.4.1
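After installing from requirements.txt, you can sanity-check that the ROCm build of PyTorch is the one in your venv. A quick sketch (the version strings in the comments are illustrative, not exact):

```python
import torch

print(torch.__version__)          # a ROCm wheel reports something like "2.6.0+rocm6.4.1"
print(torch.version.hip)          # HIP/ROCm version string; None on CUDA-only builds
print(torch.cuda.is_available())  # ROCm devices are exposed through the CUDA API
```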
What does it do when you run it?
It builds a small toy transformer and trains it on random inputs, so the resulting model is deliberately useless. The run starts by printing the model:
```
(.venv) mark@wide:~/prog/fsdp2-minimal-rocm$ ./dotrain.sh
FSDPTransformer(
  (tok_embeddings): Embedding(1024, 16)
  (pos_embeddings): Embedding(64, 16)
  (dropout): Dropout(p=0, inplace=False)
  (layers): ModuleList(
    (0-9): 10 x FSDPTransformerBlock(
      (attention_norm): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
      (attention): Attention(
        (resid_dropout): Dropout(p=0, inplace=False)
        (wq): Linear(in_features=16, out_features=16, bias=False)
        (wk): Linear(in_features=16, out_features=16, bias=False)
        (wv): Linear(in_features=16, out_features=16, bias=False)
        (wo): Linear(in_features=16, out_features=16, bias=False)
      )
      (ffn_norm): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
      (feed_forward): FeedForward(
        (w1): Linear(in_features=16, out_features=64, bias=True)
        (gelu): GELU(approximate='none')
        (w2): Linear(in_features=64, out_features=16, bias=True)
        (resid_dropout): Dropout(p=0, inplace=False)
      )
    )
  )
  (norm): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
  (output): Linear(in_features=16, out_features=1024, bias=False)
)
```
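The interesting part is how that model gets sharded. FSDP2's `fully_shard` is applied per transformer block and then to the root module. Here is a minimal sketch of that pattern; it shows the general FSDP2 recipe rather than the literal code from train.py, and it assumes torch.distributed is already initialized:

```python
import torch
from torch.distributed.fsdp import fully_shard  # FSDP2 API, public as of PyTorch 2.6

def shard(model: torch.nn.Module) -> torch.nn.Module:
    # Shard each block separately so parameters are all-gathered
    # one block at a time during forward and backward.
    for layer in model.layers:
        fully_shard(layer)
    # Then shard the root module (embeddings, final norm, output head).
    fully_shard(model)
    return model
```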
The model dump is followed by a bunch of warnings about AOTriton, which you can ignore.

---

Original README below.
To run FSDP2 on a transformer model:

```bash
cd distributed/FSDP2
torchrun --nproc_per_node 2 train.py
```
- On the first run, it creates a "checkpoints" folder and saves state dicts there
- On the second run, it loads from the previous checkpoints
To enable explicit prefetching:

```bash
torchrun --nproc_per_node 2 train.py --explicit-prefetch
```
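Explicit prefetching overlaps communication with compute by telling each sharded block which block's parameters to all-gather next, instead of relying on FSDP2's implicit prefetching. A hedged sketch of the idea, assuming the blocks were each wrapped with `fully_shard` as above (the exact prefetch schedule in train.py may differ):

```python
# After fully_shard, each block is an FSDPModule and gains
# set_modules_to_forward_prefetch / set_modules_to_backward_prefetch.
layers = list(model.layers)
for i, layer in enumerate(layers):
    if i + 1 < len(layers):
        # Forward walks layers 0..N-1: start all-gathering the next block early.
        layer.set_modules_to_forward_prefetch([layers[i + 1]])
    if i > 0:
        # Backward walks layers N-1..0: prefetch the previous block.
        layer.set_modules_to_backward_prefetch([layers[i - 1]])
```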
To enable mixed precision:

```bash
torchrun --nproc_per_node 2 train.py --mixed-precision
```
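Mixed precision in FSDP2 is configured with a `MixedPrecisionPolicy` passed to `fully_shard`. A sketch of the common bf16-compute / fp32-reduce setup (the dtypes train.py actually uses may differ):

```python
import torch
from torch.distributed.fsdp import MixedPrecisionPolicy, fully_shard

# Compute (and all-gather) in bfloat16, but reduce gradients in float32.
mp_policy = MixedPrecisionPolicy(
    param_dtype=torch.bfloat16,
    reduce_dtype=torch.float32,
)
for layer in model.layers:
    fully_shard(layer, mp_policy=mp_policy)
fully_shard(model, mp_policy=mp_policy)
```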
To showcase the DCP API:

```bash
torchrun --nproc_per_node 2 train.py --dcp-api
```
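DCP (torch.distributed.checkpoint) saves and loads sharded state dicts directly, so no single rank has to materialize the full model. A sketch of the save/load round trip (the checkpoint path is hypothetical; train.py's actual layout may differ):

```python
import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint.state_dict import (
    get_model_state_dict,
    set_model_state_dict,
)

# Save: every rank writes its own shards under the checkpoint directory.
state_dict = get_model_state_dict(model)
dcp.save(state_dict, checkpoint_id="checkpoints/step-0")  # hypothetical path

# Load: read the shards back in place, then install them on the model.
state_dict = get_model_state_dict(model)
dcp.load(state_dict, checkpoint_id="checkpoints/step-0")
set_model_state_dict(model, state_dict)
```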
See https://pytorch.org/get-started/locally/ to install PyTorch 2.5 or later, ideally a current nightly build.