23 changes: 7 additions & 16 deletions skyrl-train/skyrl_train/trainer.py
@@ -1074,26 +1074,17 @@ def train_critic_and_policy(self, data: TrainingInputBatch):
"""
Run the training step for the policy and critic models.

For Megatron strategy: uses ppo_train (training loop inside worker)
For FSDP strategy: uses forward_backward + optim_step (training loop in trainer)
Uses forward_backward + optim_step for both FSDP and Megatron strategies.
"""
data.metadata["global_step"] = self.global_step
critic_status = None

if self.cfg.trainer.strategy == "megatron":
# Megatron: training loop inside worker via ppo_train
if self.has_critic:
with Timer("critic_train", self.all_timings):
critic_status = self.dispatch.ppo_train("critic", data)
with Timer("policy_train", self.all_timings):
policy_status = self.dispatch.ppo_train("policy", data)
else:
# FSDP: training loop in trainer via forward_backward + optim_step
if self.has_critic:
with Timer("critic_train", self.all_timings):
critic_status = self._execute_training_step("critic", data)
with Timer("policy_train", self.all_timings):
policy_status = self._execute_training_step("policy", data)
# Unified training interface for both FSDP and Megatron
if self.has_critic:
with Timer("critic_train", self.all_timings):
critic_status = self._execute_training_step("critic", data)
with Timer("policy_train", self.all_timings):
policy_status = self._execute_training_step("policy", data)

# Update metrics
if critic_status is not None:
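For context, a minimal sketch (not the actual implementation) of how the trainer-side `_execute_training_step` could drive both strategies through the dispatch's `forward_backward` and `optim_step` methods added in this PR; the mini-batch slicing helper below is hypothetical:

```python
from typing import Dict

def _execute_training_step(self, model: str, data) -> Dict[str, float]:
    """Illustrative sketch of the unified trainer-side training step.

    Assumes the dispatch exposes forward_backward(model, batch) and
    optim_step(model) as in this PR; the mini-batch slicing helper is
    hypothetical.
    """
    status: Dict[str, float] = {}
    # Hypothetical helper: split the training batch into mini-batches,
    # one optimizer step per mini-batch.
    for mini_batch in self._split_into_mini_batches(data, model):
        # The worker splits each mini-batch into micro-batches internally
        # and accumulates gradients across them.
        status = self.dispatch.forward_backward(model, mini_batch)
        grad_norm = self.dispatch.optim_step(model)
        if grad_norm is not None:
            status["raw_grad_norm"] = grad_norm
    return status
```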
6 changes: 0 additions & 6 deletions skyrl-train/skyrl_train/workers/fsdp/fsdp_worker.py
@@ -113,9 +113,6 @@ def init_model(self, model_path, num_training_steps: int = None):

self._is_lora = self.cfg.trainer.policy.model.lora.rank > 0

# Update per-gpu mini batch size based on device mesh
self._normalize_mini_batch_size()

model_config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
init_context = get_init_weight_context_manager(
use_meta_tensor=not model_config.tie_word_embeddings, mesh=self.strategy.device_mesh
@@ -276,9 +273,6 @@ def init_model(self, model_path, num_training_steps: int = None):
strategy.setup_distributed()
self.strategy = strategy

# Update per-gpu mini batch size based on device mesh
self._normalize_mini_batch_size()

model_config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
init_context = get_init_weight_context_manager(
use_meta_tensor=not model_config.tie_word_embeddings, mesh=self.strategy.device_mesh
223 changes: 90 additions & 133 deletions skyrl-train/skyrl_train/workers/megatron/megatron_worker.py
@@ -9,7 +9,6 @@
from datetime import timedelta
from typing import List, Dict, Any, Optional
from collections import defaultdict
from tqdm import tqdm
from omegaconf import OmegaConf

from megatron.bridge import AutoBridge
@@ -29,7 +28,7 @@
from skyrl_train.distributed.megatron.megatron_utils import print_model_size, broadcast_object_across_pp_ranks
from skyrl_train.utils.utils import update_model_config, str_to_torch_dtype
from skyrl_train.env_vars import SKYRL_WORKER_NCCL_TIMEOUT_IN_S
from skyrl_train.training_batch import TrainingOutputBatch
from skyrl_train.training_batch import TrainingInputBatch, TrainingOutputBatch
from skyrl_train.workers.worker_utils import BatchIterator, reduce_metrics
from skyrl_train.workers.worker import (
PolicyWorkerBase,
@@ -431,20 +430,6 @@ def _broadcast_no_grad(*args, **kwargs):
pp_size=mpu.get_pipeline_model_parallel_world_size(),
)

def _normalize_mini_batch_size(self):
"""
Override to set Megatron-specific batch size attributes.

Megatron's ppo_train method needs policy_mini_batch_size_per_gpu to compute
how many micro batches fit in a mini batch for gradient accumulation.
"""
super()._normalize_mini_batch_size() # Sets _micro_batches_accumulated

# Megatron-specific: compute mini batch size per GPU for ppo_train
n_samples = self.cfg.generator.n_samples_per_prompt
dp_size = self.mesh_rank.dp_size
self.policy_mini_batch_size_per_gpu = (self.cfg.trainer.policy_mini_batch_size * n_samples) // dp_size

def init_model(self, model_path, num_training_steps: int = 1e9):
"""
Initialize the model, optimizer, and scheduler for the policy worker.
@@ -487,8 +472,6 @@ def init_model(self, model_path, num_training_steps: int = 1e9):
)
self.optimizer = get_megatron_optimizer(self.actor_module, optim_config)

self._normalize_mini_batch_size()

# create scheduler
self.scheduler = get_megatron_optimizer_param_scheduler(
optimizer=self.optimizer,
@@ -519,136 +502,110 @@

self.empty_cuda_cache = self.cfg.trainer.policy.megatron_config.empty_cuda_cache

def ppo_train(self, train_data) -> "TrainingOutputBatch":
def forward_backward(self, data: TrainingInputBatch) -> Dict[str, float]:
"""
Overrides `PolicyWorkerBase.ppo_train` for megatron.
Perform forward and backward passes for a batch, handling micro-batching internally.

Since we want megatron to handle gradient accumulation over micro batches, we directly pass mini batches into the
worker MegatronModelWrapper.forward_backward_mini_batch method.
"""
dataloader = BatchIterator(
train_data, sample_batch_size=self.cfg.trainer.micro_train_batch_size_per_gpu, drop_last=False
)
The batch is split into micro batches based on micro_train_batch_size_per_gpu.
Megatron Core's forward_backward_func handles gradient accumulation internally.

micro_batches_per_mini_batch = (
self.policy_mini_batch_size_per_gpu // self.cfg.trainer.micro_train_batch_size_per_gpu
)
Args:
data: TrainingInputBatch (already DP-sharded by WorkerDispatch/MeshDispatch)

status_list = []
Returns:
Aggregated metrics dict across all micro batches
"""
self.model.train()
for chunk in self.actor_module:
# if use distributed optimizer, zero grad buffer will be handled by optimizer
chunk.zero_grad_buffer()

micro_batch_size = self.cfg.trainer.micro_train_batch_size_per_gpu
all_metrics = defaultdict(list)
policy_update_steps = 0

if self.profiler is not None:
self.profiler.start()
# Move data to GPU
data.to(torch.cuda.current_device())

# Build micro-batch dicts expected by forward_backward_mini_batch
micro_buffer = []
for experience in BatchIterator(data, micro_batch_size, drop_last=False):
sequences = experience.sequences
attention_mask = experience.attention_mask
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 0)

for epoch in range(self.cfg.trainer.update_epochs_per_batch):
self.optimizer.zero_grad()
pbar = tqdm(
dataloader,
desc=f"Policy Train epoch [{epoch + 1}/{self.cfg.trainer.update_epochs_per_batch}]",
disable=not self.strategy.is_rank_0(),
micro_buffer.append(
{
"sequences": sequences,
"attention_mask": attention_mask,
"position_ids": position_ids,
"num_actions": experience.num_actions,
"old_action_log_probs": experience.action_log_probs,
"base_action_log_probs": experience.base_action_log_probs,
"advantages": experience.advantages,
"loss_mask": experience.loss_mask,
"rollout_action_logprobs": experience.rollout_logprobs,
}
)

# TODO: Convert this into 2 loops for minibatches and microbatches.
micro_buffer = []
for local_step, experience in enumerate(pbar):
# BatchIterator now yields Experience objects directly
experience.to_device(torch.cuda.current_device())
sequences = experience.sequences
attention_mask = experience.attention_mask
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 0)

micro_buffer.append(
{
"sequences": sequences,
"attention_mask": attention_mask,
"position_ids": position_ids,
"num_actions": experience.num_actions,
"old_action_log_probs": experience.action_log_probs,
"base_action_log_probs": experience.base_action_log_probs,
"advantages": experience.advantages,
"loss_mask": experience.loss_mask,
"rollout_action_logprobs": experience.rollout_logprobs,
}
)
if not micro_buffer:
return {}

if len(micro_buffer) == micro_batches_per_mini_batch:
# run mini-batch forward-backward and then one optimizer step
self.model.train()
for chunk in self.actor_module:
# if use distributed optimizer, zero grad buffer will be handled by optimizer
chunk.zero_grad_buffer()
seq_len = micro_buffer[0]["sequences"].shape[1]
micro_bsz = micro_buffer[0]["sequences"].shape[0]

metrics_list = self.model.forward_backward_mini_batch(
micro_batches=micro_buffer,
seq_len=seq_len,
micro_batch_size=micro_bsz,
temperature=self.cfg.generator.sampling_params.temperature,
)
seq_len = micro_buffer[0]["sequences"].shape[1]
micro_bsz = micro_buffer[0]["sequences"].shape[0]

if self.empty_cuda_cache:
torch.cuda.empty_cache()

grad_norm = self.strategy.optimizer_step(self.optimizer, self.model, self.scheduler, name="actor")

# within a DP group, metrics are already the same across all workers - we then just all reduce across
# the whole world size to get the metrics for the global micro batch
for i, metrics in enumerate(metrics_list):
status = {
"final_loss": metrics["final_loss"],
"policy_loss": metrics["policy_loss"],
"policy_lr": self.optimizer.param_groups[0]["lr"],
"ppo_clip_ratio": metrics["ppo_clip_ratio"],
"policy_entropy": metrics["policy_entropy"],
}
if self.cfg.trainer.algorithm.use_kl_loss:
status["policy_kl"] = metrics["policy_kl"]

# Attach grad norm only for the last micro in the mini-batch
if i == len(metrics_list) - 1 and grad_norm is not None:
status["raw_grad_norm"] = grad_norm

# attach response_length
status["response_length"] = micro_buffer[i]["num_actions"]

status = self.strategy.all_reduce(status)
status_list.append(status)
for k, v in status.items():
all_metrics[k].append(v)

short_status = {
"pg": status_list[-1]["policy_loss"],
"glen": status_list[-1]["response_length"],
"policy_lr": status_list[-1]["policy_lr"],
"ent": status_list[-1]["policy_entropy"],
}
if "raw_grad_norm" in status_list[-1]:
short_status["grad_norm"] = status_list[-1]["raw_grad_norm"]
pbar.set_postfix(short_status)

policy_update_steps += 1
micro_buffer = []

# drop any trailing micros that don't fill a mini-batch (keep behavior consistent)
micro_buffer = []
metrics_list = self.model.forward_backward_mini_batch(
micro_batches=micro_buffer,
seq_len=seq_len,
micro_batch_size=micro_bsz,
temperature=self.cfg.generator.sampling_params.temperature,
)

Review comment (Contributor, severity: medium):
The temperature parameter is typically associated with sampling in generation tasks. While it might be used internally by MegatronModelWrapper.forward_backward_mini_batch for entropy calculation or similar, its presence in a training method's signature can be confusing. Consider adding a comment to clarify its specific role in the training forward/backward pass, or if possible, refactor MegatronModelWrapper to only expose parameters relevant to training loss calculation in this context.

torch.distributed.barrier()
if self.profiler is not None:
self.profiler.stop_and_save()
self.profiler.stop_trace()
if self.empty_cuda_cache:
torch.cuda.empty_cache()

# not needed beyond status logging
all_metrics.pop("response_length", None)
# Track number of micro-batches for metrics
self._micro_batches_accumulated += len(micro_buffer)

status_mean = reduce_metrics(all_metrics)
status_mean["policy_update_steps"] = policy_update_steps
# Aggregate metrics across micro-batches
for metrics in metrics_list:
for k, v in metrics.items():
all_metrics[k].append(v)

output = TrainingOutputBatch()
output.metadata = {"train_status": status_mean}
return output
# Reduce and all-reduce metrics
status = reduce_metrics(dict(all_metrics))
status["policy_lr"] = self.optimizer.param_groups[0]["lr"]
status = self.strategy.all_reduce(status)

return status

def optim_step(self) -> Optional[float]:
"""
Perform optimizer step.

Note: Unlike FSDP workers, Megatron doesn't need manual gradient scaling here
because Megatron Core's forward_backward_func handles loss scaling internally.

Returns:
The gradient norm (before scaling, after clipping), or None if unavailable.
"""
Review comment (Collaborator):
Need the gradient scaling for grad accumulation here, I think? Looked into it briefly and it seemed doable.

Review comment (Collaborator):
@tyler-griggs given the offline discussion of moving to loss sums, maybe we don't need to do any scaling here anymore, actually? We will need to figure out how Megatron's internal gradient accumulation works and whether we can enforce that it also only does sums.

grad_norm = self.strategy.optimizer_step(self.optimizer, self.model, self.scheduler, name="actor")

# Reset counter for next accumulation cycle
self._micro_batches_accumulated = 0

if grad_norm is not None:
grad_norm = grad_norm.detach().cpu().item() if hasattr(grad_norm, "item") else grad_norm
return grad_norm

def get_lr(self) -> float:
"""
Get current learning rate from optimizer.

Override base class method because Megatron's OptimizerParamScheduler
doesn't have get_last_lr() like PyTorch schedulers.
"""
return self.optimizer.param_groups[0]["lr"]

async def broadcast_to_inference_engines(self, inference_engine_client):
use_prefix_cache = self.cfg.generator.enable_prefix_caching
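On the temperature comment above: one plausible reason the training path takes `temperature` is that training-time log-probabilities and entropy need to be computed under the same temperature-scaled distribution that generated the rollouts. A standalone sketch of that idea (plain PyTorch, not the actual `MegatronModelWrapper` code):

```python
import torch
import torch.nn.functional as F

def action_logprobs_and_entropy(logits: torch.Tensor,
                                actions: torch.Tensor,
                                temperature: float = 1.0):
    """Sketch: per-token log-probs and entropy under the temperature-scaled
    policy, so they match the sampling distribution used during rollouts."""
    scaled = logits / temperature                       # match rollout sampling
    log_probs = F.log_softmax(scaled, dim=-1)           # [B, T, V]
    action_log_probs = log_probs.gather(-1, actions.unsqueeze(-1)).squeeze(-1)
    probs = log_probs.exp()
    entropy = -(probs * log_probs).sum(dim=-1)          # [B, T]
    return action_log_probs, entropy
```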
34 changes: 0 additions & 34 deletions skyrl-train/skyrl_train/workers/worker.py
@@ -639,23 +639,6 @@ def __init__(self, **kwargs):
self.record_memory: bool = False
self.mesh_rank: MeshRank = None
self.policy_loss_fn: Callable = PolicyLossRegistry.get(self.cfg.trainer.algorithm.policy_loss_type)

def _normalize_mini_batch_size(self):
"""
Initialize micro batch tracking for gradient accumulation.

The worker no longer needs to know mini batch size - it processes whatever
batch it receives, breaking it into micro batches. Gradient scaling happens
at optim_step time based on how many micro batches were accumulated.

TODO: Rename to _init_gradient_accumulation_state once Megatron no longer
requires mini-batch normalization in its override. The name is kept for
backwards compatibility with Megatron which still does actual normalization.
"""
if not hasattr(self, "mesh_rank") or self.mesh_rank is None:
raise RuntimeError("mesh_rank must be initialized before calling _normalize_mini_batch_size()")

# Track micro batches for gradient scaling at optim_step
self._micro_batches_accumulated = 0

def forward_backward(self, data: TrainingInputBatch) -> Dict[str, float]:
@@ -886,23 +869,6 @@ def __init__(self, **kwargs):
self.record_memory: bool = False
self.mesh_rank: MeshRank = None
self.critic_loss_fn: Callable = ppo_critic_loss

def _normalize_mini_batch_size(self):
"""
Initialize micro batch tracking for gradient accumulation.

The worker no longer needs to know mini batch size - it processes whatever
batch it receives, breaking it into micro batches. Gradient scaling happens
at optim_step time based on how many micro batches were accumulated.

TODO: Rename to _init_gradient_accumulation_state once Megatron no longer
requires mini-batch normalization in its override. The name is kept for
backwards compatibility with Megatron which still does actual normalization.
"""
if not hasattr(self, "mesh_rank") or self.mesh_rank is None:
raise RuntimeError("mesh_rank must be initialized before calling _normalize_mini_batch_size()")

# Track micro batches for gradient scaling at optim_step
self._micro_batches_accumulated = 0

def forward_backward(self, data: TrainingInputBatch) -> Dict[str, float]:
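The removed docstrings describe the new contract: workers count micro-batches during `forward_backward` and apply any gradient scaling at `optim_step`. A rough sketch of that pattern under a mean-loss convention is below; the helper methods are hypothetical, and under the loss-sum approach discussed in the Megatron review thread the division would be dropped:

```python
# Sketch of the worker-side contract: count micro-batches in forward_backward,
# scale gradients once at optim_step. Methods marked "hypothetical" are not
# part of the actual codebase.
def forward_backward(self, data):
    all_metrics = []
    for micro_batch in self._split_into_micro_batches(data):  # hypothetical
        loss, metrics = self._compute_loss(micro_batch)       # hypothetical
        loss.backward()                     # gradients accumulate in .grad
        self._micro_batches_accumulated += 1
        all_metrics.append(metrics)
    return self._reduce_metrics(all_metrics)                  # hypothetical

def optim_step(self):
    # Assuming each micro-batch loss is a mean, dividing the accumulated
    # gradients by the micro-batch count approximates a full-batch mean.
    # Under a loss-sum scheme this division would be unnecessary.
    if self._micro_batches_accumulated > 1:
        for param in self.model.parameters():
            if param.grad is not None:
                param.grad.div_(self._micro_batches_accumulated)
    grad_norm = self.strategy.optimizer_step(self.optimizer, self.model, self.scheduler)
    self._micro_batches_accumulated = 0
    return grad_norm
```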
10 changes: 0 additions & 10 deletions skyrl-train/skyrl_train/workers/worker_dispatch.py
@@ -166,16 +166,6 @@ def optim_step(self, model: str) -> Optional[float]:
self._save_memory_snapshot(model, "optim_step")
return grad_norms[0]

# TODO(tgriggs): Remove this when Megatron supports forward_backward and optim_step.
def ppo_train(self, model: str, data: TrainingInputBatch) -> Dict[str, float]:
"""Run full PPO training loop (for Megatron)."""
self._ensure_on_gpu(model, need_optimizer=True, need_model=True)

refs = self._actor_groups[model].async_run_ray_method("mesh", "ppo_train", data)
statuses = ray.get(refs)

return statuses[0].metadata["train_status"]

def _save_memory_snapshot(self, model: str, tag: str) -> None:
"""Save memory snapshot on workers."""
ray.get(
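With `ppo_train` removed from the dispatch, the Megatron workers presumably go through the same `forward_backward`/`optim_step` entry points as FSDP. A sketch of what the `forward_backward` dispatch might look like, mirroring the Ray fan-out pattern of the deleted method (illustrative only):

```python
from typing import Dict

import ray

def forward_backward(self, model: str, data) -> Dict[str, float]:
    """Sketch: fan the DP-sharded batch out to the model's actor group,
    mirroring the Ray pattern of the removed ppo_train dispatch."""
    self._ensure_on_gpu(model, need_optimizer=True, need_model=True)
    refs = self._actor_groups[model].async_run_ray_method("mesh", "forward_backward", data)
    statuses = ray.get(refs)
    # Workers all-reduce their metrics, so any single rank's dict is representative.
    return statuses[0]
```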