From 97d757a1f6ddcdc08e46d6833f0c3d888c4dd811 Mon Sep 17 00:00:00 2001 From: Tyler Griggs Date: Tue, 20 Jan 2026 17:20:05 +0000 Subject: [PATCH 1/5] Unify Megatron and FSDP training interfaces with forward_backward + optim_step - Add forward_backward() and optim_step() methods to MegatronPolicyWorkerBase - Update trainer to use unified interface for both strategies - Remove strategy branching in train_critic_and_policy() - Mark ppo_train() as deprecated (kept for backward compatibility) - Update test_megatron_worker.py to use new interface Co-Authored-By: Eric Tang Co-Authored-By: Claude Opus 4.5 --- skyrl-train/skyrl_train/trainer.py | 23 ++-- .../workers/megatron/megatron_worker.py | 113 +++++++++++++++++- .../skyrl_train/workers/worker_dispatch.py | 5 +- .../tests/gpu/gpu_ci/test_megatron_worker.py | 9 +- 4 files changed, 127 insertions(+), 23 deletions(-) diff --git a/skyrl-train/skyrl_train/trainer.py b/skyrl-train/skyrl_train/trainer.py index e570e6df4..5da35ef19 100644 --- a/skyrl-train/skyrl_train/trainer.py +++ b/skyrl-train/skyrl_train/trainer.py @@ -1074,26 +1074,17 @@ def train_critic_and_policy(self, data: TrainingInputBatch): """ Run the training step for the policy and critic models. - For Megatron strategy: uses ppo_train (training loop inside worker) - For FSDP strategy: uses forward_backward + optim_step (training loop in trainer) + Uses forward_backward + optim_step for both FSDP and Megatron strategies. """ data.metadata["global_step"] = self.global_step critic_status = None - if self.cfg.trainer.strategy == "megatron": - # Megatron: training loop inside worker via ppo_train - if self.has_critic: - with Timer("critic_train", self.all_timings): - critic_status = self.dispatch.ppo_train("critic", data) - with Timer("policy_train", self.all_timings): - policy_status = self.dispatch.ppo_train("policy", data) - else: - # FSDP: training loop in trainer via forward_backward + optim_step - if self.has_critic: - with Timer("critic_train", self.all_timings): - critic_status = self._execute_training_step("critic", data) - with Timer("policy_train", self.all_timings): - policy_status = self._execute_training_step("policy", data) + # Unified training interface for both FSDP and Megatron + if self.has_critic: + with Timer("critic_train", self.all_timings): + critic_status = self._execute_training_step("critic", data) + with Timer("policy_train", self.all_timings): + policy_status = self._execute_training_step("policy", data) # Update metrics if critic_status is not None: diff --git a/skyrl-train/skyrl_train/workers/megatron/megatron_worker.py b/skyrl-train/skyrl_train/workers/megatron/megatron_worker.py index 53bb6a709..ecd620189 100644 --- a/skyrl-train/skyrl_train/workers/megatron/megatron_worker.py +++ b/skyrl-train/skyrl_train/workers/megatron/megatron_worker.py @@ -29,7 +29,7 @@ from skyrl_train.distributed.megatron.megatron_utils import print_model_size, broadcast_object_across_pp_ranks from skyrl_train.utils.utils import update_model_config, str_to_torch_dtype from skyrl_train.utils.constants import SKYRL_WORKER_NCCL_TIMEOUT_IN_S -from skyrl_train.training_batch import TrainingOutputBatch +from skyrl_train.training_batch import TrainingInputBatch, TrainingOutputBatch from skyrl_train.workers.worker_utils import BatchIterator, reduce_metrics from skyrl_train.workers.worker import ( PolicyWorkerBase, @@ -519,12 +519,117 @@ def init_model(self, model_path, num_training_steps: int = 1e9): self.empty_cuda_cache = self.cfg.trainer.policy.megatron_config.empty_cuda_cache + def 
forward_backward(self, data: TrainingInputBatch) -> Dict[str, float]: + """ + Perform forward and backward passes for a batch, handling micro-batching internally. + + The batch is split into micro batches based on micro_train_batch_size_per_gpu. + Megatron Core's forward_backward_func handles gradient accumulation internally. + + Args: + data: TrainingInputBatch (already DP-sharded by WorkerDispatch/MeshDispatch) + + Returns: + Aggregated metrics dict across all micro batches + """ + self.model.train() + for chunk in self.actor_module: + # if use distributed optimizer, zero grad buffer will be handled by optimizer + chunk.zero_grad_buffer() + + micro_batch_size = self.cfg.trainer.micro_train_batch_size_per_gpu + all_metrics = defaultdict(list) + + # Move data to GPU + data.to(torch.cuda.current_device()) + + # Build micro-batch dicts expected by forward_backward_mini_batch + micro_buffer = [] + for experience in BatchIterator(data, micro_batch_size, drop_last=False): + sequences = experience.sequences + attention_mask = experience.attention_mask + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 0) + + micro_buffer.append( + { + "sequences": sequences, + "attention_mask": attention_mask, + "position_ids": position_ids, + "num_actions": experience.num_actions, + "old_action_log_probs": experience.action_log_probs, + "base_action_log_probs": experience.base_action_log_probs, + "advantages": experience.advantages, + "loss_mask": experience.loss_mask, + "rollout_action_logprobs": experience.rollout_logprobs, + } + ) + + if not micro_buffer: + return {} + + seq_len = micro_buffer[0]["sequences"].shape[1] + micro_bsz = micro_buffer[0]["sequences"].shape[0] + + metrics_list = self.model.forward_backward_mini_batch( + micro_batches=micro_buffer, + seq_len=seq_len, + micro_batch_size=micro_bsz, + temperature=self.cfg.generator.sampling_params.temperature, + ) + + if self.empty_cuda_cache: + torch.cuda.empty_cache() + + # Track number of micro-batches for metrics + self._micro_batches_accumulated += len(micro_buffer) + + # Aggregate metrics across micro-batches + for metrics in metrics_list: + for k, v in metrics.items(): + all_metrics[k].append(v) + + # Reduce and all-reduce metrics + status = reduce_metrics(dict(all_metrics)) + status["policy_lr"] = self.optimizer.param_groups[0]["lr"] + status = self.strategy.all_reduce(status) + + return status + + def optim_step(self) -> Optional[float]: + """ + Perform optimizer step. + + Note: Unlike FSDP workers, Megatron doesn't need manual gradient scaling here + because Megatron Core's forward_backward_func handles loss scaling internally. + + Returns: + The gradient norm (before scaling, after clipping), or None if unavailable. + """ + grad_norm = self.strategy.optimizer_step(self.optimizer, self.model, self.scheduler, name="actor") + + # Reset counter for next accumulation cycle + self._micro_batches_accumulated = 0 + + if grad_norm is not None: + grad_norm = grad_norm.detach().cpu().item() if hasattr(grad_norm, "item") else grad_norm + return grad_norm + + def get_lr(self) -> float: + """ + Get current learning rate from optimizer. + + Override base class method because Megatron's OptimizerParamScheduler + doesn't have get_last_lr() like PyTorch schedulers. + """ + return self.optimizer.param_groups[0]["lr"] + def ppo_train(self, train_data) -> "TrainingOutputBatch": """ - Overrides `PolicyWorkerBase.ppo_train` for megatron. + DEPRECATED: Use forward_backward() + optim_step() instead. 
- Since we want megatron to handle gradient accumulation over micro batches, we directly pass mini batches into the - worker MegatronModelWrapper.forward_backward_mini_batch method. + This method is kept for backward compatibility with existing scripts. + The trainer now uses forward_backward() + optim_step() for both FSDP and Megatron. """ dataloader = BatchIterator( train_data, sample_batch_size=self.cfg.trainer.micro_train_batch_size_per_gpu, drop_last=False diff --git a/skyrl-train/skyrl_train/workers/worker_dispatch.py b/skyrl-train/skyrl_train/workers/worker_dispatch.py index cff67e703..750f7827c 100644 --- a/skyrl-train/skyrl_train/workers/worker_dispatch.py +++ b/skyrl-train/skyrl_train/workers/worker_dispatch.py @@ -168,7 +168,10 @@ def optim_step(self, model: str) -> Optional[float]: # TODO(tgriggs): Remove this when Megatron supports forward_backward and optim_step. def ppo_train(self, model: str, data: TrainingInputBatch) -> Dict[str, float]: - """Run full PPO training loop (for Megatron).""" + """DEPRECATED: Use forward_backward() + optim_step() instead. + + This method is kept for backward compatibility with existing scripts. + """ self._ensure_on_gpu(model, need_optimizer=True, need_model=True) refs = self._actor_groups[model].async_run_ray_method("mesh", "ppo_train", data) diff --git a/skyrl-train/tests/gpu/gpu_ci/test_megatron_worker.py b/skyrl-train/tests/gpu/gpu_ci/test_megatron_worker.py index 928bd9df4..74fd3f55f 100644 --- a/skyrl-train/tests/gpu/gpu_ci/test_megatron_worker.py +++ b/skyrl-train/tests/gpu/gpu_ci/test_megatron_worker.py @@ -488,10 +488,15 @@ async def test_megatron_train( cfg=cfg, ) + # Use forward_backward + optim_step (unified interface for both megatron and FSDP) with Timer(f"megatron training step tp{tp} pp{pp} cp{cp} ep{ep} etp{etp}"): batch.metadata["global_step"] = 0 - results_megatron = ray.get(actor_group.async_run_ray_method("pass_through", "ppo_train", batch)) - results_megatron = [results_megatron[i].metadata["train_status"] for i in range(len(results_megatron))] + results_megatron = ray.get(actor_group.async_run_ray_method("mesh", "forward_backward", batch)) + ray.get(actor_group.async_run_ray_method("pass_through", "optim_step")) + # Get learning rate from worker + lr_results = ray.get(actor_group.async_run_ray_method("pass_through", "get_lr")) + for i, result in enumerate(results_megatron): + result["policy_lr"] = lr_results[i] memory = ray.get(actor_group.async_run_ray_method("pass_through", "get_cuda_memory")) memory = memory[0] From 13ce6985dddfdf112a8485f39fa6c5cef8779e52 Mon Sep 17 00:00:00 2001 From: Tyler Griggs Date: Tue, 20 Jan 2026 17:44:03 +0000 Subject: [PATCH 2/5] Remove deprecated ppo_train method, update all tests to use unified interface - Remove ppo_train from MegatronPolicyWorkerBase and WorkerDispatch - Update test_megatron_dp, test_megatron_offload to use forward_backward + optim_step - Update test_save_load_model.py and test_save_load_checkpoint.py for unified interface - Simplify _normalize_mini_batch_size (no longer needs policy_mini_batch_size_per_gpu) Both FSDP and Megatron now use the same forward_backward + optim_step interface. 
Co-Authored-By: Eric Tang Co-Authored-By: Claude Opus 4.5 --- .../workers/megatron/megatron_worker.py | 142 +----------------- .../skyrl_train/workers/worker_dispatch.py | 13 -- .../tests/gpu/gpu_ci/test_megatron_worker.py | 25 ++- .../gpu/gpu_ci/test_save_load_checkpoint.py | 12 +- skyrl-train/tests/gpu/test_save_load_model.py | 12 +- 5 files changed, 30 insertions(+), 174 deletions(-) diff --git a/skyrl-train/skyrl_train/workers/megatron/megatron_worker.py b/skyrl-train/skyrl_train/workers/megatron/megatron_worker.py index ecd620189..bf81bd03e 100644 --- a/skyrl-train/skyrl_train/workers/megatron/megatron_worker.py +++ b/skyrl-train/skyrl_train/workers/megatron/megatron_worker.py @@ -433,18 +433,13 @@ def _broadcast_no_grad(*args, **kwargs): def _normalize_mini_batch_size(self): """ - Override to set Megatron-specific batch size attributes. + Initialize micro batch tracking for gradient accumulation. - Megatron's ppo_train method needs policy_mini_batch_size_per_gpu to compute - how many micro batches fit in a mini batch for gradient accumulation. + Megatron uses the same interface as FSDP - the base class implementation + sets up _micro_batches_accumulated for tracking. """ super()._normalize_mini_batch_size() # Sets _micro_batches_accumulated - # Megatron-specific: compute mini batch size per GPU for ppo_train - n_samples = self.cfg.generator.n_samples_per_prompt - dp_size = self.mesh_rank.dp_size - self.policy_mini_batch_size_per_gpu = (self.cfg.trainer.policy_mini_batch_size * n_samples) // dp_size - def init_model(self, model_path, num_training_steps: int = 1e9): """ Initialize the model, optimizer, and scheduler for the policy worker. @@ -624,137 +619,6 @@ def get_lr(self) -> float: """ return self.optimizer.param_groups[0]["lr"] - def ppo_train(self, train_data) -> "TrainingOutputBatch": - """ - DEPRECATED: Use forward_backward() + optim_step() instead. - - This method is kept for backward compatibility with existing scripts. - The trainer now uses forward_backward() + optim_step() for both FSDP and Megatron. - """ - dataloader = BatchIterator( - train_data, sample_batch_size=self.cfg.trainer.micro_train_batch_size_per_gpu, drop_last=False - ) - - micro_batches_per_mini_batch = ( - self.policy_mini_batch_size_per_gpu // self.cfg.trainer.micro_train_batch_size_per_gpu - ) - - status_list = [] - all_metrics = defaultdict(list) - policy_update_steps = 0 - - if self.profiler is not None: - self.profiler.start() - - for epoch in range(self.cfg.trainer.update_epochs_per_batch): - self.optimizer.zero_grad() - pbar = tqdm( - dataloader, - desc=f"Policy Train epoch [{epoch + 1}/{self.cfg.trainer.update_epochs_per_batch}]", - disable=not self.strategy.is_rank_0(), - ) - - # TODO: Convert this into 2 loops for minibatches and microbatches. 
- micro_buffer = [] - for local_step, experience in enumerate(pbar): - # BatchIterator now yields Experience objects directly - experience.to_device(torch.cuda.current_device()) - sequences = experience.sequences - attention_mask = experience.attention_mask - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 0) - - micro_buffer.append( - { - "sequences": sequences, - "attention_mask": attention_mask, - "position_ids": position_ids, - "num_actions": experience.num_actions, - "old_action_log_probs": experience.action_log_probs, - "base_action_log_probs": experience.base_action_log_probs, - "advantages": experience.advantages, - "loss_mask": experience.loss_mask, - "rollout_action_logprobs": experience.rollout_logprobs, - } - ) - - if len(micro_buffer) == micro_batches_per_mini_batch: - # run mini-batch forward-backward and then one optimizer step - self.model.train() - for chunk in self.actor_module: - # if use distributed optimizer, zero grad buffer will be handled by optimizer - chunk.zero_grad_buffer() - seq_len = micro_buffer[0]["sequences"].shape[1] - micro_bsz = micro_buffer[0]["sequences"].shape[0] - - metrics_list = self.model.forward_backward_mini_batch( - micro_batches=micro_buffer, - seq_len=seq_len, - micro_batch_size=micro_bsz, - temperature=self.cfg.generator.sampling_params.temperature, - ) - - if self.empty_cuda_cache: - torch.cuda.empty_cache() - - grad_norm = self.strategy.optimizer_step(self.optimizer, self.model, self.scheduler, name="actor") - - # within a DP group, metrics are already the same across all workers - we then just all reduce across - # the whole world size to get the metrics for the global micro batch - for i, metrics in enumerate(metrics_list): - status = { - "final_loss": metrics["final_loss"], - "policy_loss": metrics["policy_loss"], - "policy_lr": self.optimizer.param_groups[0]["lr"], - "ppo_clip_ratio": metrics["ppo_clip_ratio"], - "policy_entropy": metrics["policy_entropy"], - } - if self.cfg.trainer.algorithm.use_kl_loss: - status["policy_kl"] = metrics["policy_kl"] - - # Attach grad norm only for the last micro in the mini-batch - if i == len(metrics_list) - 1 and grad_norm is not None: - status["raw_grad_norm"] = grad_norm - - # attach response_length - status["response_length"] = micro_buffer[i]["num_actions"] - - status = self.strategy.all_reduce(status) - status_list.append(status) - for k, v in status.items(): - all_metrics[k].append(v) - - short_status = { - "pg": status_list[-1]["policy_loss"], - "glen": status_list[-1]["response_length"], - "policy_lr": status_list[-1]["policy_lr"], - "ent": status_list[-1]["policy_entropy"], - } - if "raw_grad_norm" in status_list[-1]: - short_status["grad_norm"] = status_list[-1]["raw_grad_norm"] - pbar.set_postfix(short_status) - - policy_update_steps += 1 - micro_buffer = [] - - # drop any trailing micros that don't fill a mini-batch (keep behavior consistent) - micro_buffer = [] - - torch.distributed.barrier() - if self.profiler is not None: - self.profiler.stop_and_save() - self.profiler.stop_trace() - - # not needed beyond status logging - all_metrics.pop("response_length", None) - - status_mean = reduce_metrics(all_metrics) - status_mean["policy_update_steps"] = policy_update_steps - - output = TrainingOutputBatch() - output.metadata = {"train_status": status_mean} - return output - async def broadcast_to_inference_engines(self, inference_engine_client): use_prefix_cache = self.cfg.generator.enable_prefix_caching generator_dtype = 
str_to_torch_dtype(self.cfg.generator.model_dtype) diff --git a/skyrl-train/skyrl_train/workers/worker_dispatch.py b/skyrl-train/skyrl_train/workers/worker_dispatch.py index 750f7827c..a34af8dc2 100644 --- a/skyrl-train/skyrl_train/workers/worker_dispatch.py +++ b/skyrl-train/skyrl_train/workers/worker_dispatch.py @@ -166,19 +166,6 @@ def optim_step(self, model: str) -> Optional[float]: self._save_memory_snapshot(model, "optim_step") return grad_norms[0] - # TODO(tgriggs): Remove this when Megatron supports forward_backward and optim_step. - def ppo_train(self, model: str, data: TrainingInputBatch) -> Dict[str, float]: - """DEPRECATED: Use forward_backward() + optim_step() instead. - - This method is kept for backward compatibility with existing scripts. - """ - self._ensure_on_gpu(model, need_optimizer=True, need_model=True) - - refs = self._actor_groups[model].async_run_ray_method("mesh", "ppo_train", data) - statuses = ray.get(refs) - - return statuses[0].metadata["train_status"] - def _save_memory_snapshot(self, model: str, tag: str) -> None: """Save memory snapshot on workers.""" ray.get( diff --git a/skyrl-train/tests/gpu/gpu_ci/test_megatron_worker.py b/skyrl-train/tests/gpu/gpu_ci/test_megatron_worker.py index 74fd3f55f..af4ad17e0 100644 --- a/skyrl-train/tests/gpu/gpu_ci/test_megatron_worker.py +++ b/skyrl-train/tests/gpu/gpu_ci/test_megatron_worker.py @@ -536,7 +536,7 @@ async def test_megatron_train( cfg=cfg, ) - # FSDP uses forward_backward + optim_step instead of ppo_train + # Both FSDP and Megatron use forward_backward + optim_step (unified interface) batch.metadata["global_step"] = 0 results_fsdp = ray.get(actor_group.async_run_ray_method("mesh", "forward_backward", batch)) ray.get(actor_group.async_run_ray_method("pass_through", "optim_step")) @@ -602,10 +602,13 @@ async def test_megatron_dp(ray_init_fixture, worker_type, tp, pp, gpus_per_node) cfg=cfg, ) - # call ppo_train with a batch of size 4 per gpu + # Use forward_backward + optim_step (unified interface) batch.metadata["global_step"] = 0 - results_megatron = ray.get(actor_group.async_run_ray_method("mesh", "ppo_train", batch)) - results_megatron = [results_megatron[i].metadata["train_status"] for i in range(len(results_megatron))] + results_megatron = ray.get(actor_group.async_run_ray_method("mesh", "forward_backward", batch)) + ray.get(actor_group.async_run_ray_method("pass_through", "optim_step")) + lr_results = ray.get(actor_group.async_run_ray_method("pass_through", "get_lr")) + for i, result in enumerate(results_megatron): + result["policy_lr"] = lr_results[i] memory = ray.get(actor_group.async_run_ray_method("pass_through", "get_cuda_memory")) memory = memory[0] @@ -644,8 +647,11 @@ async def test_megatron_dp(ray_init_fixture, worker_type, tp, pp, gpus_per_node) cfg=cfg, ) - results_megatron_dp = ray.get(actor_group.async_run_ray_method("mesh", "ppo_train", batch)) - results_megatron_dp = [results_megatron_dp[i].metadata["train_status"] for i in range(len(results_megatron_dp))] + results_megatron_dp = ray.get(actor_group.async_run_ray_method("mesh", "forward_backward", batch)) + ray.get(actor_group.async_run_ray_method("pass_through", "optim_step")) + lr_results_dp = ray.get(actor_group.async_run_ray_method("pass_through", "get_lr")) + for i, result in enumerate(results_megatron_dp): + result["policy_lr"] = lr_results_dp[i] print("megatron results: ", results_megatron) print("\n\n") @@ -716,7 +722,9 @@ async def test_megatron_offload_memory_and_correctness(ray_init_fixture, worker_ 
get_rank_0_memory(actor_group, "Before training") batch = get_test_training_batch() - results = ray.get(actor_group.async_run_ray_method("pass_through", "ppo_train", batch)) + # Use forward_backward + optim_step (unified interface) + results = ray.get(actor_group.async_run_ray_method("mesh", "forward_backward", batch)) + ray.get(actor_group.async_run_ray_method("pass_through", "optim_step")) after_training = get_rank_0_memory(actor_group, "After training") @@ -761,7 +769,8 @@ async def test_megatron_offload_memory_and_correctness(ray_init_fixture, worker_ get_rank_0_memory(actor_group, "After backload") # Run training again and ensure output consistency - results_backload = ray.get(actor_group.async_run_ray_method("pass_through", "ppo_train", batch)) + results_backload = ray.get(actor_group.async_run_ray_method("mesh", "forward_backward", batch)) + ray.get(actor_group.async_run_ray_method("pass_through", "optim_step")) for i, result in enumerate(results): result_backload = results_backload[i] diff --git a/skyrl-train/tests/gpu/gpu_ci/test_save_load_checkpoint.py b/skyrl-train/tests/gpu/gpu_ci/test_save_load_checkpoint.py index 7918bb4b0..a1dc7d714 100644 --- a/skyrl-train/tests/gpu/gpu_ci/test_save_load_checkpoint.py +++ b/skyrl-train/tests/gpu/gpu_ci/test_save_load_checkpoint.py @@ -32,13 +32,11 @@ def run_one_training_step( megatron_batch=None, ): """Run forward_backward + optim_step to perform one training step.""" - if strategy == "megatron": - assert megatron_batch is not None, "Megatron requires a TrainingInputBatch for ppo_train" - return ray.get(actor_group.async_run_ray_method("mesh", "ppo_train", megatron_batch)) - else: - assert data is not None, f"{strategy} requires a TrainingInputBatch for forward_backward" - ray.get(actor_group.async_run_ray_method("mesh", "forward_backward", data=data)) - ray.get(actor_group.async_run_ray_method("pass_through", "optim_step")) + # Unified interface for all strategies (megatron, fsdp, fsdp2) + batch = megatron_batch if strategy == "megatron" else data + assert batch is not None, f"{strategy} requires a TrainingInputBatch for forward_backward" + ray.get(actor_group.async_run_ray_method("mesh", "forward_backward", data=batch)) + ray.get(actor_group.async_run_ray_method("pass_through", "optim_step")) def get_test_actor_config(strategy: str) -> DictConfig: diff --git a/skyrl-train/tests/gpu/test_save_load_model.py b/skyrl-train/tests/gpu/test_save_load_model.py index 71593cb86..3577b26b1 100644 --- a/skyrl-train/tests/gpu/test_save_load_model.py +++ b/skyrl-train/tests/gpu/test_save_load_model.py @@ -58,13 +58,11 @@ def run_one_training_step( megatron_batch=None, ): """Run forward_backward + optim_step to perform one training step.""" - if strategy == "megatron": - assert megatron_batch is not None, "Megatron requires a TrainingInputBatch for ppo_train" - return ray.get(actor_group.async_run_ray_method("mesh", "ppo_train", megatron_batch)) - else: - assert data is not None, f"{strategy} requires a TrainingInputBatch for forward_backward" - ray.get(actor_group.async_run_ray_method("mesh", "forward_backward", data=data)) - ray.get(actor_group.async_run_ray_method("pass_through", "optim_step")) + # Unified interface for all strategies (megatron, fsdp, fsdp2) + batch = megatron_batch if strategy == "megatron" else data + assert batch is not None, f"{strategy} requires a TrainingInputBatch for forward_backward" + ray.get(actor_group.async_run_ray_method("mesh", "forward_backward", data=batch)) + ray.get(actor_group.async_run_ray_method("pass_through", 
"optim_step")) @pytest.mark.parametrize( From 7adcb3dc4d26ac3fcdba1b18063697e01e23b0d0 Mon Sep 17 00:00:00 2001 From: Tyler Griggs Date: Tue, 20 Jan 2026 18:29:52 +0000 Subject: [PATCH 3/5] format --- skyrl-train/skyrl_train/workers/megatron/megatron_worker.py | 1 - 1 file changed, 1 deletion(-) diff --git a/skyrl-train/skyrl_train/workers/megatron/megatron_worker.py b/skyrl-train/skyrl_train/workers/megatron/megatron_worker.py index bf81bd03e..78afb34ce 100644 --- a/skyrl-train/skyrl_train/workers/megatron/megatron_worker.py +++ b/skyrl-train/skyrl_train/workers/megatron/megatron_worker.py @@ -9,7 +9,6 @@ from datetime import timedelta from typing import List, Dict, Any, Optional from collections import defaultdict -from tqdm import tqdm from omegaconf import OmegaConf from megatron.bridge import AutoBridge From 879ac2149d28ac5c32271231ac0215c465c1e672 Mon Sep 17 00:00:00 2001 From: Tyler Griggs Date: Tue, 20 Jan 2026 20:03:12 +0000 Subject: [PATCH 4/5] Remove _init_gradient_accumulation_state method, initialize in __init__ The method just set _micro_batches_accumulated = 0, which can be done directly in __init__. This removes unnecessary indirection and the vestigial mesh_rank guard that was no longer needed. Co-Authored-By: Claude Opus 4.5 --- .../skyrl_train/workers/fsdp/fsdp_worker.py | 6 -- .../workers/megatron/megatron_worker.py | 11 --- skyrl-train/skyrl_train/workers/worker.py | 34 ------- skyrl-train/tests/cpu/test_trainer.py | 97 +++++-------------- 4 files changed, 26 insertions(+), 122 deletions(-) diff --git a/skyrl-train/skyrl_train/workers/fsdp/fsdp_worker.py b/skyrl-train/skyrl_train/workers/fsdp/fsdp_worker.py index 285ef618c..a3ade9879 100644 --- a/skyrl-train/skyrl_train/workers/fsdp/fsdp_worker.py +++ b/skyrl-train/skyrl_train/workers/fsdp/fsdp_worker.py @@ -113,9 +113,6 @@ def init_model(self, model_path, num_training_steps: int = None): self._is_lora = self.cfg.trainer.policy.model.lora.rank > 0 - # Update per-gpu mini batch size based on device mesh - self._normalize_mini_batch_size() - model_config = AutoConfig.from_pretrained(model_path, trust_remote_code=True) init_context = get_init_weight_context_manager( use_meta_tensor=not model_config.tie_word_embeddings, mesh=self.strategy.device_mesh @@ -276,9 +273,6 @@ def init_model(self, model_path, num_training_steps: int = None): strategy.setup_distributed() self.strategy = strategy - # Update per-gpu mini batch size based on device mesh - self._normalize_mini_batch_size() - model_config = AutoConfig.from_pretrained(model_path, trust_remote_code=True) init_context = get_init_weight_context_manager( use_meta_tensor=not model_config.tie_word_embeddings, mesh=self.strategy.device_mesh diff --git a/skyrl-train/skyrl_train/workers/megatron/megatron_worker.py b/skyrl-train/skyrl_train/workers/megatron/megatron_worker.py index 78afb34ce..463c7e96e 100644 --- a/skyrl-train/skyrl_train/workers/megatron/megatron_worker.py +++ b/skyrl-train/skyrl_train/workers/megatron/megatron_worker.py @@ -430,15 +430,6 @@ def _broadcast_no_grad(*args, **kwargs): pp_size=mpu.get_pipeline_model_parallel_world_size(), ) - def _normalize_mini_batch_size(self): - """ - Initialize micro batch tracking for gradient accumulation. - - Megatron uses the same interface as FSDP - the base class implementation - sets up _micro_batches_accumulated for tracking. 
- """ - super()._normalize_mini_batch_size() # Sets _micro_batches_accumulated - def init_model(self, model_path, num_training_steps: int = 1e9): """ Initialize the model, optimizer, and scheduler for the policy worker. @@ -481,8 +472,6 @@ def init_model(self, model_path, num_training_steps: int = 1e9): ) self.optimizer = get_megatron_optimizer(self.actor_module, optim_config) - self._normalize_mini_batch_size() - # create scheduler self.scheduler = get_megatron_optimizer_param_scheduler( optimizer=self.optimizer, diff --git a/skyrl-train/skyrl_train/workers/worker.py b/skyrl-train/skyrl_train/workers/worker.py index 912eb2dc3..08b2a1ecf 100644 --- a/skyrl-train/skyrl_train/workers/worker.py +++ b/skyrl-train/skyrl_train/workers/worker.py @@ -639,23 +639,6 @@ def __init__(self, **kwargs): self.record_memory: bool = False self.mesh_rank: MeshRank = None self.policy_loss_fn: Callable = PolicyLossRegistry.get(self.cfg.trainer.algorithm.policy_loss_type) - - def _normalize_mini_batch_size(self): - """ - Initialize micro batch tracking for gradient accumulation. - - The worker no longer needs to know mini batch size - it processes whatever - batch it receives, breaking it into micro batches. Gradient scaling happens - at optim_step time based on how many micro batches were accumulated. - - TODO: Rename to _init_gradient_accumulation_state once Megatron no longer - requires mini-batch normalization in its override. The name is kept for - backwards compatibility with Megatron which still does actual normalization. - """ - if not hasattr(self, "mesh_rank") or self.mesh_rank is None: - raise RuntimeError("mesh_rank must be initialized before calling _normalize_mini_batch_size()") - - # Track micro batches for gradient scaling at optim_step self._micro_batches_accumulated = 0 def forward_backward(self, data: TrainingInputBatch) -> Dict[str, float]: @@ -886,23 +869,6 @@ def __init__(self, **kwargs): self.record_memory: bool = False self.mesh_rank: MeshRank = None self.critic_loss_fn: Callable = ppo_critic_loss - - def _normalize_mini_batch_size(self): - """ - Initialize micro batch tracking for gradient accumulation. - - The worker no longer needs to know mini batch size - it processes whatever - batch it receives, breaking it into micro batches. Gradient scaling happens - at optim_step time based on how many micro batches were accumulated. - - TODO: Rename to _init_gradient_accumulation_state once Megatron no longer - requires mini-batch normalization in its override. The name is kept for - backwards compatibility with Megatron which still does actual normalization. - """ - if not hasattr(self, "mesh_rank") or self.mesh_rank is None: - raise RuntimeError("mesh_rank must be initialized before calling _normalize_mini_batch_size()") - - # Track micro batches for gradient scaling at optim_step self._micro_batches_accumulated = 0 def forward_backward(self, data: TrainingInputBatch) -> Dict[str, float]: diff --git a/skyrl-train/tests/cpu/test_trainer.py b/skyrl-train/tests/cpu/test_trainer.py index a704278e6..d45a9efca 100644 --- a/skyrl-train/tests/cpu/test_trainer.py +++ b/skyrl-train/tests/cpu/test_trainer.py @@ -169,15 +169,8 @@ def test_calc_advantages_and_returns(mock_compute_adv_and_ret, dummy_config): ) -def test_normalize_mini_batch_size(): - """Test the _normalize_mini_batch_size method initializes micro batch tracking. - - Workers don't need to know mini batch sizes per GPU. - They receive batches from the trainer and split them into micro batches. 
- _normalize_mini_batch_size only initializes micro batch tracking for gradient scaling. - - # TODO: Update naming once Megatron is updated to not be aware of mini batch sizes. - """ +def test_micro_batches_accumulated_initialized(): + """Test that _micro_batches_accumulated is initialized to 0 in worker __init__.""" # Create minimal worker instances for testing class TestPolicyWorker(PolicyWorkerBase): @@ -206,73 +199,35 @@ def backload_to_gpu(self, non_blocking=True): def _forward_micro_batch(self, micro_batch): pass - def create_policy_worker_with_config(dp_size): - """Helper to create policy worker with specific config.""" - cfg = get_default_config() - cfg.trainer.algorithm.policy_loss_type = "regular" - - worker = TestPolicyWorker( - cfg=cfg, - world_size=dp_size, - rank=0, - local_rank=0, - master_addr="localhost", - master_port=12345, - sequence_parallel_size=1, - ) - - # Mock mesh_rank - worker.mesh_rank = MeshRank(dp=0, sp=0, tp=0, pp=0, world_size=dp_size, dp_size=dp_size, pp_size=1) - - return worker - - def create_critic_worker_with_config(dp_size): - """Helper to create critic worker with specific config.""" - cfg = get_default_config() - - worker = TestCriticWorker( - cfg=cfg, - world_size=dp_size, - rank=0, - local_rank=0, - master_addr="localhost", - master_port=12345, - sequence_parallel_size=1, - ) - - # Mock mesh_rank - worker.mesh_rank = MeshRank(dp=0, sp=0, tp=0, pp=0, world_size=dp_size, dp_size=dp_size, pp_size=1) - - return worker - - # Test Case 1: PolicyWorker initializes _micro_batches_accumulated - policy_worker = create_policy_worker_with_config(dp_size=4) - policy_worker._normalize_mini_batch_size() + cfg = get_default_config() + cfg.trainer.algorithm.policy_loss_type = "regular" + # PolicyWorker has _micro_batches_accumulated initialized at construction + policy_worker = TestPolicyWorker( + cfg=cfg, + world_size=4, + rank=0, + local_rank=0, + master_addr="localhost", + master_port=12345, + sequence_parallel_size=1, + ) assert hasattr(policy_worker, "_micro_batches_accumulated") assert policy_worker._micro_batches_accumulated == 0 - # Test Case 2: CriticWorker initializes _micro_batches_accumulated - critic_worker = create_critic_worker_with_config(dp_size=4) - critic_worker._normalize_mini_batch_size() - + # CriticWorker has _micro_batches_accumulated initialized at construction + critic_worker = TestCriticWorker( + cfg=cfg, + world_size=4, + rank=0, + local_rank=0, + master_addr="localhost", + master_port=12345, + sequence_parallel_size=1, + ) assert hasattr(critic_worker, "_micro_batches_accumulated") assert critic_worker._micro_batches_accumulated == 0 - # Test Case 3: Single GPU (dp_size=1) for PolicyWorker - policy_worker = create_policy_worker_with_config(dp_size=1) - policy_worker._normalize_mini_batch_size() - - assert hasattr(policy_worker, "_micro_batches_accumulated") - assert policy_worker._micro_batches_accumulated == 0 - - # Test Case 4: Error case - mesh_rank not initialized - policy_worker_no_mesh = create_policy_worker_with_config(dp_size=4) - policy_worker_no_mesh.mesh_rank = None - - with pytest.raises(RuntimeError, match="mesh_rank must be initialized"): - policy_worker_no_mesh._normalize_mini_batch_size() - def test_validate_batch_sizes(): """Test the validate_batch_sizes function with various configurations to trigger all error cases.""" @@ -503,7 +458,7 @@ def create_test_worker(worker_class): # Test PolicyWorkerBase policy_worker = create_test_worker(PolicyWorkerBase) - # Initialize _micro_batches_accumulated (normally done in 
_normalize_mini_batch_size) + # Reset _micro_batches_accumulated (initialized in __init__, reset here for test isolation) policy_worker._micro_batches_accumulated = 0 # Mock _forward_backward_micro to track calls @@ -541,7 +496,7 @@ def mock_policy_forward_backward_micro(experience): # Test CriticWorkerBase with same pattern critic_worker = create_test_worker(CriticWorkerBase) - # Initialize _micro_batches_accumulated (normally done in _normalize_mini_batch_size) + # Reset _micro_batches_accumulated (initialized in __init__, reset here for test isolation) critic_worker._micro_batches_accumulated = 0 # Mock _forward_backward_micro for critic From 6cbb2ea428e1797d5546d68d0e3838f6c563c6c4 Mon Sep 17 00:00:00 2001 From: Eric Tang Date: Thu, 22 Jan 2026 00:43:53 +0000 Subject: [PATCH 5/5] x --- skyrl-train/tests/cpu/test_trainer.py | 1 - 1 file changed, 1 deletion(-) diff --git a/skyrl-train/tests/cpu/test_trainer.py b/skyrl-train/tests/cpu/test_trainer.py index d45a9efca..5a85e6b00 100644 --- a/skyrl-train/tests/cpu/test_trainer.py +++ b/skyrl-train/tests/cpu/test_trainer.py @@ -10,7 +10,6 @@ from jaxtyping import Float, Integer from pytest import approx from skyrl_train.config.utils import get_default_config -from skyrl_train.distributed.dispatch import MeshRank from skyrl_train.trainer import RayPPOTrainer from skyrl_train.training_batch import TrainingInputBatch from skyrl_train.utils.utils import validate_batch_sizes
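
For reference, a minimal sketch of the unified calling pattern both strategies now share, mirroring the updated test helpers in this series (the helper name run_one_training_step echoes the test helper; the sketch is illustrative and omits offload handling and metric reduction):

    import ray

    def run_one_training_step(actor_group, batch):
        # Forward/backward over the already DP-sharded batch; the worker splits it
        # into micro batches and accumulates gradients internally.
        metrics_per_worker = ray.get(
            actor_group.async_run_ray_method("mesh", "forward_backward", batch)
        )
        # One optimizer step after accumulation; workers return the grad norm (or None).
        grad_norms = ray.get(
            actor_group.async_run_ray_method("pass_through", "optim_step")
        )
        # Learning rate is fetched separately via the workers' get_lr() helper.
        lrs = ray.get(actor_group.async_run_ray_method("pass_through", "get_lr"))
        for i, metrics in enumerate(metrics_per_worker):
            metrics["policy_lr"] = lrs[i]
        return metrics_per_worker, grad_norms

The same sequence applies to the critic worker group; in the trainer, _execute_training_step wraps this pattern for both the policy and critic models.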