feat: refactor init of dtensor policy v2 #1709
base: main
Conversation
Force-pushed: 935ed9c to 23b5525
Signed-off-by: Hemil Desai <hemild@nvidia.com>
Force-pushed: 4f66b8f to 81174e5
📝 Walkthrough

This PR introduces a modular setup utilities module for automodel-based training in NeMo RL, encompassing runtime configuration validation, model state management, distributed training orchestration, and model/optimizer initialization. The existing DTensorPolicyWorkerV2 initialization logic is refactored to leverage these new setup functions, replacing monolithic in-place construction with a staged pipeline.

Changes
Sequence Diagram

```mermaid
sequenceDiagram
    participant Config as PolicyConfig
    participant Validator as validate_and_prepare_config
    participant Distributor as setup_distributed
    participant RefModel as setup_reference_model_state
    participant ModelSetup as setup_model_and_optimizer
    participant Worker as DTensorPolicyWorkerV2

    Worker->>Validator: Pass config + processor + rank
    Validator->>Validator: Derive dtype, attn_impl, seq_packing
    Validator-->>Worker: Return RuntimeConfig
    Worker->>Distributor: Pass config + runtime_config
    Distributor->>Distributor: Init distributed, FSDP2Manager
    Distributor-->>Worker: Return distributed_manager
    Worker->>RefModel: Pass model
    RefModel->>RefModel: CPU-pin state dict
    RefModel-->>Worker: Return reference_model_state_dict
    Worker->>ModelSetup: Pass config, tokenizer, runtime_config, distributed_manager
    ModelSetup->>ModelSetup: Build model + LoRA
    ModelSetup->>ModelSetup: Apply parallelism, load weights
    ModelSetup->>ModelSetup: Initialize optimizer + scheduler
    ModelSetup-->>Worker: Return ModelAndOptimizerState
    Worker->>Worker: Populate attributes from state objects
```
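To make the staged pipeline concrete, here is a rough sketch of how the pieces wire together. The function and type names come from this PR, but the exact signatures, argument names, ordering inside the worker, and return-value attributes (e.g. `state.model`) are assumptions for illustration only:

```python
from nemo_rl.models.automodel.setup import (
    setup_distributed,
    setup_model_and_optimizer,
    setup_reference_model_state,
    validate_and_prepare_config,
)


def init_pipeline_sketch(config, tokenizer, processor=None):
    """Illustrative wiring of the staged init; not the actual worker code."""
    # Stage 1: derive dtype, attn_impl, sequence packing, etc. from the config.
    runtime_config = validate_and_prepare_config(config=config, processor=processor, rank=0)
    # Stage 2: initialize torch.distributed and the FSDP2Manager.
    distributed_manager = setup_distributed(config=config, runtime_config=runtime_config)
    # Stage 3: build the model (+ optional LoRA), apply parallelism, load weights,
    # and create the optimizer/scheduler.
    state = setup_model_and_optimizer(
        config=config,
        tokenizer=tokenizer,
        runtime_config=runtime_config,
        distributed_manager=distributed_manager,
    )
    # Stage 4: keep a CPU-pinned copy of the initial weights as the reference model.
    reference_state = setup_reference_model_state(state.model)
    return runtime_config, distributed_manager, state, reference_state
```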
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~50 minutes

Possibly related PRs
Suggested labels
Suggested reviewers
🚥 Pre-merge checks (✅ 3 | ❌ 1)

❌ Failed checks (1 warning)
✅ Passed checks (3 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing touches
Actionable comments posted: 2
🤖 Fix all issues with AI agents
In @nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py:
- Around line 190-194: The code calls validate_and_prepare_config(config=config,
processor=processor, rank=0) and leaves rank=0 with a misleading comment; either
compute the real rank from the distributed init and pass it into
validate_and_prepare_config (or re-call validate_and_prepare_config after
distributed initialization with the actual rank) or remove/update the comment to
reflect that rank is not used beyond prints. Update the call site that sets
runtime_config and/or the surrounding comment so validate_and_prepare_config
receives the real rank (or the comment accurately states why a placeholder is
acceptable).
- Around line 186-213: The call to _init_checkpoint_manager is happening before
mesh attributes are created and thus can access undefined
self.dp_mesh/self.tp_mesh/self.moe_mesh; move the _init_checkpoint_manager(...)
invocation to after setup_distributed returns and after the code that assigns
self.dp_mesh, self.tp_mesh, and self.moe_mesh (i.e., once distributed_manager is
used to set those mesh attributes), keeping the same config_updates payload and
using the already-prepared runtime_config/distributed_manager context.
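A rough sketch of the ordering this prompt asks for, shown below. The mesh attribute names (`dp_mesh`, `tp_mesh`, `moe_mesh`) come from the review itself; how they are read off `distributed_manager` and the helper signatures are guesses:

```python
from nemo_rl.models.automodel.setup import setup_distributed


def ordered_init_sketch(worker, config, runtime_config, config_updates):
    """Illustrative ordering only: meshes must exist before the checkpoint manager."""
    distributed_manager = setup_distributed(config=config, runtime_config=runtime_config)
    # Populate the mesh attributes first...
    worker.dp_mesh = distributed_manager.dp_mesh
    worker.tp_mesh = distributed_manager.tp_mesh
    worker.moe_mesh = distributed_manager.moe_mesh
    # ...and only then build the checkpoint manager, which reads the meshes.
    worker._init_checkpoint_manager(config_updates)
```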
🧹 Nitpick comments (7)
tests/unit/models/automodel/test_automodel_setup.py (1)
22-26: Consider removing unused `pytest_plugins` declaration.

The `pytest_plugins = []` on line 22 is unused. Also, Ruff flags the `noqa: F401` directive as unused since that rule isn't enabled, but the directive is harmless and may be needed for other linting tools.

♻️ Optional cleanup

```diff
-pytest_plugins = []
 try:
-    import nemo_automodel  # noqa: F401
+    import nemo_automodel
 except ImportError:
     pytest.skip("nemo_automodel not available", allow_module_level=True)
```

nemo_rl/models/automodel/setup.py (5)
403-406: Consider using explicit exception instead of assert for configuration validation.

Using `assert` for configuration validation can be disabled with the `-O` flag. For user-facing configuration errors, prefer explicit exceptions.

```diff
 if tp_size > 1:
-    assert not lora_cfg["use_triton"], (
-        "Triton is not supported when tensor_parallel_size > 1"
-    )
+    if lora_cfg["use_triton"]:
+        raise ValueError(
+            "Triton is not supported when tensor_parallel_size > 1"
+        )
```
516-516: Consider using logging instead of print for model architecture output.
`print(model)` can produce very verbose output for large models. Consider using a logger with an appropriate level (e.g., DEBUG) or making this conditional on a verbosity setting.
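For example, a minimal sketch of the suggested change (the logger name and wrapper function are illustrative, not from the PR):

```python
import logging

logger = logging.getLogger(__name__)


def log_model_architecture(model) -> None:
    # DEBUG keeps the potentially huge module tree out of default log output.
    logger.debug("Model architecture:\n%s", model)
```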
582-586: Default scheduler uses identity lambda with unused parameter.

The default LambdaLR scheduler uses `lambda epoch: 1`, which Ruff flags for the unused `epoch` parameter. This is functionally correct but could use `_` to indicate the parameter is intentionally unused.

```diff
 # Default to passthrough LR schedule
 scheduler = torch.optim.lr_scheduler.LambdaLR(
-    optimizer, lr_lambda=lambda epoch: 1
+    optimizer, lr_lambda=lambda _: 1
 )
```
459-478: Consider using `ValueError` instead of `AssertionError` for configuration validation.

These raise `AssertionError` for configuration validation, but `ValueError` would be more appropriate and consistent with other validation errors in this module. Assertions can be disabled with Python's `-O` flag.

```diff
 if cp_size > 1:
     if isinstance(model, Gemma3ForCausalLM):
-        raise AssertionError(
+        raise ValueError(
             "Context parallel is not supported for Gemma3ForCausalLM. ..."
         )
     if tp_size > 1 and sequence_parallel_enabled:
-        raise AssertionError(
+        raise ValueError(
             "It's a known issue that context parallel can't be used together with sequence parallel in DTensor worker. ..."
         )
     if is_vlm:
-        raise AssertionError(
+        raise ValueError(
             "Context parallel is yet not supported for VLM models. ..."
         )
```
336-340: Calling private method `_setup_distributed()` on external FSDP2Manager.

The code calls `manager._setup_distributed()` to force distributed setup for single-GPU cases, as documented in the comment. Consider adding a reference to the FSDP2Manager implementation or documenting why the manager skips setup for world_size=1 to help future maintainers understand the workaround.

nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py (1)
252-266: Redundant tuple unpacking with duplicate assignments.
`model_class` and `model_config` are assigned twice - first from `model_and_optimizer_state` (lines 237-238) and again from `runtime_config` (lines 254-255). The `_runtime_is_reward_model` variable on line 265 is also never used after assignment.

♻️ Proposed cleanup

```diff
 # Set instance attributes from runtime config (tuple unpacking)
 (
-    self.model_class,  # Already set above, but includes in tuple for completeness
-    self.model_config,  # Already set above, but includes in tuple for completeness
+    _,  # model_class already set from model_and_optimizer_state
+    _,  # model_config already set from model_and_optimizer_state
     self.hf_config_overrides,
     self.allow_flash_attn_args,
     self.attn_impl,
     self.dtype,
     self.enable_seq_packing,
     self.max_grad_norm,
     self.cpu_offload,
     self.offload_optimizer_for_logprob,
     self.is_generation_colocated,
-    _runtime_is_reward_model,  # Duplicate, already set as _is_reward_model
+    _,  # is_reward_model already set as _is_reward_model
 ) = runtime_config
```
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (5)
- nemo_rl/models/automodel/__init__.py
- nemo_rl/models/automodel/setup.py
- nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py
- tests/unit/models/automodel/__init__.py
- tests/unit/models/automodel/test_automodel_setup.py
🧰 Additional context used
📓 Path-based instructions (4)
**/*.py
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
**/*.py: Conform code to Python 3.12+
Indent code with 4 spaces. Do not use tabs
Use snake_case for file names
Use PascalCase for class names
Use snake_case for function and method names
Use snake_case for local variables
Prefix variable names that start with a number with 'k' (e.g., k_99th_percentile)
Use upper snake_case with 'G' prefix for global variables (e.g., G_MY_GLOBAL)
Use upper snake_case for constants
Avoid shadowing variables declared in an outer scope
Initialize all externally visible members of a class in the constructor
Prefer docstrings over comments for interfaces that may be used outside a file
Reserve comments for code within a function or interfaces that are local to a file
If a piece of code is commented out, include a comment describing its usage and why it's commented out. Remove debug comments before merging
Use Google style docstrings for classes and functions in Python, which can be parsed by Sphinx
Avoid using reflection when functionality can be easily achieved without reflection
When using try-except blocks, limit the except clause to the smallest set of specific errors possible
When using try-except blocks for duck-typing, keep the body of the try as small as possible and use the else block for logic
YAML is the single source of truth for configuration defaults. Do not set non-None defaults in code for configuration values
For required configuration attributes, access config directly and expect presence (e.g., policy_cfg['precision']) without hidden defaults
Use typing.NotRequired to mark optional attributes in TypedDict for configuration
When adding a new config key to a TypedDict subclass, document the key's purpose, valid values/types, and recommended default, and reflect the default in exemplar YAMLs under examples/configs/*.yaml
Follow the Google Python Style Guide for Python code
Files:
- tests/unit/models/automodel/__init__.py
- nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py
- nemo_rl/models/automodel/__init__.py
- tests/unit/models/automodel/test_automodel_setup.py
- nemo_rl/models/automodel/setup.py
!(**/tests/**|**/test_*.py|**/test_*.sh)
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
Add the NVIDIA copyright header to all Python files and shell scripts (excluding tests). The header should include the current year
Files:
- tests/unit/models/automodel/__init__.py
- nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py
- nemo_rl/models/automodel/__init__.py
- tests/unit/models/automodel/test_automodel_setup.py
- nemo_rl/models/automodel/setup.py
**/*.{py,sh}
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
The NVIDIA copyright header should appear at the top of all Python files and shell scripts (excluding tests)
Files:
- tests/unit/models/automodel/__init__.py
- nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py
- nemo_rl/models/automodel/__init__.py
- tests/unit/models/automodel/test_automodel_setup.py
- nemo_rl/models/automodel/setup.py
nemo_rl/**/*.py
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
For any source file under nemo_rl/*.py that defines a class or function decorated with @ray.remote, add a coverage pragma (# pragma: no cover) because these run in separate Ray processes
Files:
- nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py
- nemo_rl/models/automodel/__init__.py
- nemo_rl/models/automodel/setup.py
🧠 Learnings (3)
📚 Learning: 2025-11-24T17:24:41.976Z
Learnt from: CR
Repo: NVIDIA-NeMo/RL PR: 0
File: CODING_GUIDELINES.md:0-0
Timestamp: 2025-11-24T17:24:41.976Z
Learning: Applies to **/*.{py,sh} : The NVIDIA copyright header should appear at the top of all Python files and shell scripts (excluding tests)
Applied to files:
- tests/unit/models/automodel/__init__.py
- nemo_rl/models/automodel/__init__.py
📚 Learning: 2025-11-24T17:24:41.976Z
Learnt from: CR
Repo: NVIDIA-NeMo/RL PR: 0
File: CODING_GUIDELINES.md:0-0
Timestamp: 2025-11-24T17:24:41.976Z
Learning: Applies to !(**/tests/**|**/test_*.py|**/test_*.sh) : Add the NVIDIA copyright header to all Python files and shell scripts (excluding tests). The header should include the current year
Applied to files:
- tests/unit/models/automodel/__init__.py
- nemo_rl/models/automodel/__init__.py
📚 Learning: 2025-09-17T01:52:21.399Z
Learnt from: ffrujeri
Repo: NVIDIA-NeMo/RL PR: 1023
File: nemo_rl/utils/checkpoint.py:58-65
Timestamp: 2025-09-17T01:52:21.399Z
Learning: model_state_dict_keys is not intended to be part of the nemo-rl CheckpointingConfig TypedDict - it's handled at the automodel implementation layer, not as a general checkpointing configuration parameter.
Applied to files:
- nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py
- tests/unit/models/automodel/test_automodel_setup.py
🧬 Code graph analysis (2)
tests/unit/models/automodel/test_automodel_setup.py (1)
nemo_rl/models/automodel/setup.py (3)
- ModelAndOptimizerState (81-98)
- RuntimeConfig (49-78)
- validate_and_prepare_config (101-249)
nemo_rl/models/automodel/setup.py (3)
nemo_rl/models/policy/workers/dtensor_policy_worker.py (1)
- get_cpu_state_dict (103-133)

nemo_rl/models/policy/utils.py (2)
- configure_dynamo_cache (261-268)
- resolve_model_class (179-183)

nemo_rl/utils/automodel_checkpoint.py (2)
- set_model_state_dict_keys (192-200)
- load_base_model (202-244)
🪛 Ruff (0.14.10)
tests/unit/models/automodel/test_automodel_setup.py
24-24: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
85-85: Unused method argument: mock_dynamo
(ARG002)
117-117: Unused method argument: mock_dynamo
(ARG002)
118-118: Unused method argument: mock_resolve_class
(ARG002)
119-119: Unused method argument: mock_autoconfig_class
(ARG002)
137-137: Unused method argument: mock_dynamo
(ARG002)
138-138: Unused method argument: mock_resolve_class
(ARG002)
139-139: Unused method argument: mock_autoconfig_class
(ARG002)
161-161: Unused method argument: mock_rm_class
(ARG002)
162-162: Unused method argument: mock_dynamo
(ARG002)
163-163: Unused method argument: mock_resolve_class
(ARG002)
193-193: Unused method argument: mock_dynamo
(ARG002)
194-194: Unused method argument: mock_resolve_class
(ARG002)
195-195: Unused method argument: mock_autoconfig_class
(ARG002)
216-216: Unused method argument: mock_dynamo
(ARG002)
217-217: Unused method argument: mock_resolve_class
(ARG002)
218-218: Unused method argument: mock_autoconfig_class
(ARG002)
237-237: Unused method argument: mock_dynamo
(ARG002)
269-269: Unused method argument: mock_dynamo
(ARG002)
300-300: Unused method argument: mock_dynamo
(ARG002)
323-323: Unused method argument: mock_dynamo
(ARG002)
345-345: Unused method argument: mock_dynamo
(ARG002)
nemo_rl/models/automodel/setup.py
139-139: Avoid specifying long messages outside the exception class
(TRY003)
152-155: Avoid specifying long messages outside the exception class
(TRY003)
209-209: Avoid specifying long messages outside the exception class
(TRY003)
220-223: Avoid specifying long messages outside the exception class
(TRY003)
231-234: Avoid specifying long messages outside the exception class
(TRY003)
462-466: Avoid specifying long messages outside the exception class
(TRY003)
469-473: Avoid specifying long messages outside the exception class
(TRY003)
476-478: Avoid specifying long messages outside the exception class
(TRY003)
585-585: Unused lambda argument: epoch
(ARG005)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (5)
- GitHub Check: sphinx-build / Build docs
- GitHub Check: build-container / main
- GitHub Check: Lint check
- GitHub Check: Lint check
- GitHub Check: Post submodule check comment / Comment on PR
🔇 Additional comments (10)
nemo_rl/models/automodel/__init__.py (1)
1-13: LGTM! Package initializer with proper NVIDIA copyright header. The license header is correctly formatted with the appropriate year.
tests/unit/models/automodel/__init__.py (1)
1-13: LGTM! Test package initializer is correctly structured. Note that per coding guidelines, test files are exempt from the copyright header requirement, but having it doesn't cause any issues.
tests/unit/models/automodel/test_automodel_setup.py (2)
76-110: LGTM! The test suite for `validate_and_prepare_config` provides good coverage of the core validation logic. The unused mock arguments (flagged by Ruff) are intentional - they're required to receive patched values from decorators even when not directly referenced in the test body.
362-418: LGTM! Tests for `ModelAndOptimizerState` verify both full initialization and optional field handling correctly.

nemo_rl/models/automodel/setup.py (4)
48-79: LGTM! `RuntimeConfig` is well-structured with clear field groupings and appropriate type hints. Good use of NamedTuple for an immutable configuration container.
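As a toy illustration of the NamedTuple pattern being praised here (the fields below are invented and far smaller than the real `RuntimeConfig`):

```python
from typing import NamedTuple

import torch


class TinyRuntimeConfig(NamedTuple):
    dtype: torch.dtype
    attn_impl: str | None
    enable_seq_packing: bool


cfg = TinyRuntimeConfig(dtype=torch.bfloat16, attn_impl="sdpa", enable_seq_packing=False)
# Immutable fields plus positional unpacking, which is how the worker consumes it:
dtype, attn_impl, enable_seq_packing = cfg
```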
80-99: LGTM! `ModelAndOptimizerState` provides a clean container for model initialization results with appropriate optional type annotations.
126-131: Clarify intent when `is_generation_colocated` is `None`.

The condition `if not is_generation_colocated` evaluates to `True` for both `None` and `False`. This means `NCCL_CUMEM_ENABLE=1` is set when there's no generation config or when generation is explicitly not colocated. If this is intentional, consider making it explicit for clarity.

```diff
-    if not is_generation_colocated:
+    # Set NCCL_CUMEM_ENABLE when generation is not colocated (False) or not configured (None)
+    if is_generation_colocated is not True:
         os.environ["NCCL_CUMEM_ENABLE"] = "1"
```
252-271: LGTM! `setup_reference_model_state` is well-documented and provides efficient reference model state management with pinned memory for optimal CPU-GPU transfers.
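For readers unfamiliar with the pattern, here is a minimal sketch of a CPU-pinned state-dict copy. It is illustrative only: the PR's actual helper is `get_cpu_state_dict`, and this sketch ignores DTensor sharding:

```python
import torch


def cpu_pinned_state_dict(model: torch.nn.Module) -> dict[str, torch.Tensor]:
    """Copy parameters into pinned host memory for fast, async host-device transfers."""
    pinned_state = {}
    for name, param in model.state_dict().items():
        host = torch.empty(
            param.shape,
            dtype=param.dtype,
            device="cpu",
            pin_memory=torch.cuda.is_available(),
        )
        host.copy_(param.detach())
        pinned_state[name] = host
    return pinned_state
```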
nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py (2)

1660-1684: LGTM on method implementation.

The `_init_checkpoint_manager` method itself is well-structured. The issue is the calling order in `__init__` (addressed in previous comment).
243-250: LGTM! Reference model initialization and derived attribute setup are clean and follow the new modular pattern correctly.
```python
runtime_config = validate_and_prepare_config(
    config=config,
    processor=processor,
    rank=0,  # Temporary, will be updated after distributed init
)
```
rank=0 passed to validate_and_prepare_config is never updated.
The comment says "Temporary, will be updated after distributed init" but the rank value passed to validate_and_prepare_config is never updated. The rank is only used for print statements in that function, so this isn't functionally broken, but the comment is misleading.
```diff
 # Validate configuration and prepare runtime settings
 runtime_config = validate_and_prepare_config(
     config=config,
     processor=processor,
-    rank=0,  # Temporary, will be updated after distributed init
+    rank=0,  # Rank is only used for log messages before distributed init
 )
```

🤖 Prompt for AI Agents
)🤖 Prompt for AI Agents
In @nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py around lines 190 -
194, The code calls validate_and_prepare_config(config=config,
processor=processor, rank=0) and leaves rank=0 with a misleading comment; either
compute the real rank from the distributed init and pass it into
validate_and_prepare_config (or re-call validate_and_prepare_config after
distributed initialization with the actual rank) or remove/update the comment to
reflect that rank is not used beyond prints. Update the call site that sets
runtime_config and/or the surrounding comment so validate_and_prepare_config
receives the real rank (or the comment accurately states why a placeholder is
acceptable).
joyang-nv
left a comment
Good stuff! Thanks for cleaning up these. @hemildesai
yuki-97
left a comment
lgtm, thanks for the efforts! left some minor comments.
```python
# Disable dynamo autotune_local_cache to avoid crash when there's already a cache
# with different order of node_bundles
configure_dynamo_cache()
is_vlm = processor is not None
```
nit: how about just set self.is_vlm and self.lora_enabled here?
so that there's no need for Additional derived attributes below, also we can just use "is_peft": self.lora_enabled.
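A rough illustration of what that might look like in `__init__` (the class, config shape, and surrounding code here are simplified guesses, not the actual worker):

```python
class _InitSketch:
    def __init__(self, config: dict, processor=None) -> None:
        # Derive the flags once, right where is_vlm is currently computed...
        self.is_vlm = processor is not None
        self.lora_enabled = config.get("lora") is not None
        # ...so the later "additional derived attributes" block disappears and
        # call sites can pass the flag directly:
        self.checkpoint_args = {"is_peft": self.lora_enabled}
```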
Done in fix
```python
# Disable dynamo autotune_local_cache to avoid crash when there's already a cache
# with different order of node_bundles
configure_dynamo_cache()
```
feels better to keep the two comments in setup.py, so that others can know why we need to do this.
agreed
Done in fix
```python
    "flash_attention_2"
    if (self.enable_seq_packing and cp_size_cfg == 1)
    else ("sdpa" if cp_size_cfg > 1 else None)
)
```
better to keep comments here as well.
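For context, a sketch of the selection with comments restored; only the expression itself is from the diff above, and the comment wording is a guess at the original rationale:

```python
def choose_attn_impl(enable_seq_packing: bool, cp_size_cfg: int) -> str | None:
    return (
        # Sequence-packing path (no context parallelism) uses flash attention.
        "flash_attention_2"
        if (enable_seq_packing and cp_size_cfg == 1)
        # Context-parallel path uses SDPA; otherwise defer to the model default.
        else ("sdpa" if cp_size_cfg > 1 else None)
    )
```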
Done in fix
change lgtm. i have a request to maintain the comments in the policy init when migrated to setup.py
```python
# Disable dynamo autotune_local_cache to avoid crash when there's already a cache
# with different order of node_bundles
configure_dynamo_cache()
```
agreed
@hemildesai also, please run dtensor_v2 nightlies to check this refactor didn't introduce regression
Refactors init of dtensor policy v2 as part of #1589. Depends on #1695
Issues
#1589
Nightly links:
dpo - https://wandb.ai/nvidia/nemo-rl/runs/zx0io3io?nw=nwuserhemild
grpo moonlight - https://wandb.ai/nvidia/nemo-rl/runs/elsaxv44?nw=nwuserhemild
grpo qwen - https://wandb.ai/nvidia/nemo-rl/runs/kso90996?nw=nwuserhemild
sft gpt-oss - https://wandb.ai/nvidia/ruit_personal_debug/runs/g1lpu7e1
Summary by CodeRabbit
New Features
Refactor
Tests
✏️ Tip: You can customize this high-level summary in your review settings.