
Conversation

@senarvi (Contributor) commented Dec 14, 2025

What does this PR do?

The WeightAveraging callback doesn't support sharded models. The reason is that either the averaged model would have to be sharded too, or the full model parameters are needed when creating and updating the averaged model. There was a lot of interest in using EMA with FSDP, but it was left out of the original PR because it's not obvious how to implement it.

@amorehead noticed that SimpleFold uses Lightning, AveragedModel, and FSDP. It simply summons the full parameters before updating the averaged model, and that's what this PR does too.
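For illustration, a minimal sketch of that update step, assuming PyTorch's FullyShardedDataParallel and torch.optim.swa_utils.AveragedModel; the function name and arguments below are made up for this example and are not the PR's actual code:

```python
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.optim.swa_utils import AveragedModel


def update_averaged_model(model: FSDP, average_model: AveragedModel) -> None:
    # Temporarily gather the full, unsharded parameters on every rank. They
    # are re-sharded automatically when the context manager exits.
    with FSDP.summon_full_params(model):
        # AveragedModel reads the (now full) parameters of `model` and applies
        # its running-average update to its own copy of the parameters.
        average_model.update_parameters(model)
```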

The full parameters are also needed when creating the averaged model and when swapping the current and the averaged models for validation. I call pl_module.configure_model() in setup(), so the full parameters are initialized in CPU memory. SimpleFold doesn't define configure_model() at all, so I believe the result is the same. When updating the averaged model, SimpleFold doesn't use offload_to_cpu, so I don't use it either. If the entire model doesn't fit in GPU memory, you'll run out of memory at that point.
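A rough sketch of that setup step, assuming Lightning's Callback API; the class name, the 0.999 decay, and use_buffers=True are illustrative choices, not necessarily what the PR does:

```python
from lightning.pytorch import Callback, LightningModule, Trainer
from torch.optim.swa_utils import AveragedModel, get_ema_multi_avg_fn


class EMAWithFSDPSketch(Callback):
    def setup(self, trainer: Trainer, pl_module: LightningModule, stage: str) -> None:
        # Build any lazily created layers now, so that the full, still
        # unsharded parameters exist in CPU memory before the strategy wraps
        # and shards the model.
        pl_module.configure_model()
        # Create the averaged copy from the full model. Later updates summon
        # the full parameters without offload_to_cpu, so the whole model has
        # to fit in GPU memory at that point.
        self._average_model = AveragedModel(
            pl_module, multi_avg_fn=get_ema_multi_avg_fn(0.999), use_buffers=True
        )
```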

This is probably the best we can do without massive changes. Is this good enough? I don't know; I've never used FSDP. Maybe someone who has an actual use case could check whether this is useful. Tagging the people who asked about this in the original PR: @amorehead @kzrpg @npuichigo

Before submitting
  • Was this discussed/agreed via a GitHub issue? -- Not agreed, but there was a lot of interest in the original PR.
  • Did you make sure to update the documentation with your changes? -- Updated the documentation of the WeightAveraging class.
  • Did you write any new necessary tests? (not for typos and docs) -- Added one test for EMA with FSDP. Feel free to suggest more.

PR review

Anyone in the community is welcome to review the PR.
Before you start reviewing, make sure you have read the review guidelines. In short, see the following bullet list:

Reviewer checklist
  • Is this pull request ready for review? (if not, please submit in draft mode)
  • Check that all items from Before submitting are resolved
  • Make sure the title is self-explanatory and the description concisely explains the PR
  • Add labels and milestones (and optionally projects) to the PR so it can be classified

📚 Documentation preview 📚: https://pytorch-lightning--21414.org.readthedocs.build/en/21414/

github-actions bot added the "pl (Generic label for PyTorch Lightning package)" and "has conflicts" labels on Dec 14, 2025
@senarvi force-pushed the weight-averaging-fsdp branch from d795592 to 241a95f on December 14, 2025 at 16:08
codecov bot commented Dec 14, 2025

❌ 1 Tests Failed:

Tests completed  Failed  Passed  Skipped
3206             1       3205    579
View the full list of 1 ❄️ flaky test(s)
tests/tests_pytorch/utilities/test_imports.py::test_import_pytorch_lightning_with_torch_dist_unavailable

Flake rate in main: 4.00% (Passed 24 times, Failed 1 times)

Stack Traces | 3.64s run time
def test_import_pytorch_lightning_with_torch_dist_unavailable():
        """Test that the package can be imported regardless of whether torch.distributed is available."""
        code = dedent(
            """
            import torch
            try:
               # PyTorch 2.5 relies on torch,distributed._composable.fsdp not
               # existing with USE_DISTRIBUTED=0
               import torch._dynamo.variables.functions
               torch._dynamo.variables.functions._fsdp_param_group = None
            except ImportError:
               pass
    
            # pretend torch.distributed not available
            for name in list(torch.distributed.__dict__.keys()):
                if not name.startswith("__"):
                    delattr(torch.distributed, name)
    
            torch.distributed.is_available = lambda: False
    
            # needed for Dynamo in PT 2.5+ compare the torch.distributed source
            class _ProcessGroupStub:
                pass
            torch.distributed.ProcessGroup = _ProcessGroupStub
    
            import pytorch_lightning
            """
        )
        # run in complete isolation
>       assert subprocess.call([sys.executable, "-c", code]) == 0
E       assert 1 == 0
E        +  where 1 = <function call at 0x7f8ece7d3b50>(['.../pytorch-lightning/pytorch-lightning/.venv/bin/python', '-c', '\nimport torch\ntry:\n   # PyTorch 2.5 relies on torch,distributed._composable.fsdp not\n   # existing with USE_DISTRIBUTED=0\n   import torch._dynamo.variables.functions\n   torch._dynamo.variables.functions._fsdp_param_group = None\nexcept ImportError:\n   pass\n\n# pretend torch.distributed not available\nfor name in list(torch.distributed.__dict__.keys()):\n    if not name.startswith("__"):\n        delattr(torch.distributed, name)\n\ntorch.distributed.is_available = lambda: False\n\n# needed for Dynamo in PT 2.5+ compare the torch.distributed source\nclass _ProcessGroupStub:\n    pass\ntorch.distributed.ProcessGroup = _ProcessGroupStub\n\nimport pytorch_lightning\n'])
E        +    where <function call at 0x7f8ece7d3b50> = subprocess.call

utilities/test_imports.py:144: AssertionError

To view more test analytics, go to the Test Analytics Dashboard

@amorehead (Contributor) commented:

Thanks for leading this effort, @senarvi! I'm sure this feature will be useful to the PyTorch community in the coming years. After glancing through the code changes, they look good to me. As long as the revised EMA unit test passes, it should be good to go.

@bhimrazy requested a review from @deependujha on December 16, 2025 at 07:08
@senarvi force-pushed the weight-averaging-fsdp branch from d1f40af to 675ba32 on December 16, 2025 at 07:16
github-actions bot added the "docs (Documentation related)" label on Dec 16, 2025
