Unify Megatron and FSDP training interfaces with forward_backward + optim_step #901
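Below is the relevant diff for the Megatron policy worker, which drops the backend-specific ppo_train loop in favor of the forward_backward / optim_step pair already used by the FSDP workers. As a rough orientation, a trainer could then drive either backend through the same calls; the sketch below is a hypothetical driver loop (worker, mini_batches, and train_on_batch are illustrative names, not SkyRL trainer code).

```python
# Hypothetical trainer-side driver loop illustrating the unified interface.
# `worker` stands in for either a Megatron or an FSDP policy worker; both
# now expose forward_backward() and optim_step() with the same semantics.
from collections import defaultdict

def train_on_batch(worker, mini_batches):
    """Run one update pass over a list of DP-sharded TrainingInputBatch mini-batches."""
    aggregated = defaultdict(list)
    for mini_batch in mini_batches:
        # forward_backward() micro-batches internally and accumulates gradients.
        metrics = worker.forward_backward(mini_batch)
        # optim_step() clips gradients, steps the optimizer and scheduler,
        # and returns the gradient norm (or None if unavailable).
        grad_norm = worker.optim_step()
        if grad_norm is not None:
            metrics["raw_grad_norm"] = grad_norm
        for k, v in metrics.items():
            aggregated[k].append(v)
    return {k: sum(vals) / len(vals) for k, vals in aggregated.items()}
```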
@@ -9,7 +9,6 @@ | |
| from datetime import timedelta | ||
| from typing import List, Dict, Any, Optional | ||
| from collections import defaultdict | ||
| from tqdm import tqdm | ||
| from omegaconf import OmegaConf | ||
|
|
||
| from megatron.bridge import AutoBridge | ||
|
|
@@ -29,7 +28,7 @@ | |
| from skyrl_train.distributed.megatron.megatron_utils import print_model_size, broadcast_object_across_pp_ranks | ||
| from skyrl_train.utils.utils import update_model_config, str_to_torch_dtype | ||
| from skyrl_train.env_vars import SKYRL_WORKER_NCCL_TIMEOUT_IN_S | ||
| from skyrl_train.training_batch import TrainingOutputBatch | ||
| from skyrl_train.training_batch import TrainingInputBatch, TrainingOutputBatch | ||
| from skyrl_train.workers.worker_utils import BatchIterator, reduce_metrics | ||
| from skyrl_train.workers.worker import ( | ||
| PolicyWorkerBase, | ||
|
|
@@ -431,20 +430,6 @@ def _broadcast_no_grad(*args, **kwargs): | |
| pp_size=mpu.get_pipeline_model_parallel_world_size(), | ||
| ) | ||
|
|
||
| def _normalize_mini_batch_size(self): | ||
| """ | ||
| Override to set Megatron-specific batch size attributes. | ||
|
|
||
| Megatron's ppo_train method needs policy_mini_batch_size_per_gpu to compute | ||
| how many micro batches fit in a mini batch for gradient accumulation. | ||
| """ | ||
| super()._normalize_mini_batch_size() # Sets _micro_batches_accumulated | ||
|
|
||
| # Megatron-specific: compute mini batch size per GPU for ppo_train | ||
| n_samples = self.cfg.generator.n_samples_per_prompt | ||
| dp_size = self.mesh_rank.dp_size | ||
| self.policy_mini_batch_size_per_gpu = (self.cfg.trainer.policy_mini_batch_size * n_samples) // dp_size | ||
|
|
||
| def init_model(self, model_path, num_training_steps: int = 1e9): | ||
| """ | ||
| Initialize the model, optimizer, and scheduler for the policy worker. | ||
|
|
@@ -487,8 +472,6 @@ def init_model(self, model_path, num_training_steps: int = 1e9): | |
| ) | ||
| self.optimizer = get_megatron_optimizer(self.actor_module, optim_config) | ||
|
|
||
| self._normalize_mini_batch_size() | ||
|
|
||
| # create scheduler | ||
| self.scheduler = get_megatron_optimizer_param_scheduler( | ||
| optimizer=self.optimizer, | ||
|
|
@@ -519,136 +502,110 @@ def init_model(self, model_path, num_training_steps: int = 1e9): | |
|
|
||
| self.empty_cuda_cache = self.cfg.trainer.policy.megatron_config.empty_cuda_cache | ||
|
|
||
| def ppo_train(self, train_data) -> "TrainingOutputBatch": | ||
| def forward_backward(self, data: TrainingInputBatch) -> Dict[str, float]: | ||
| """ | ||
| Overrides `PolicyWorkerBase.ppo_train` for megatron. | ||
| Perform forward and backward passes for a batch, handling micro-batching internally. | ||
|
|
||
| Since we want megatron to handle gradient accumulation over micro batches, we directly pass mini batches into the | ||
| worker MegatronModelWrapper.forward_backward_mini_batch method. | ||
| """ | ||
| dataloader = BatchIterator( | ||
| train_data, sample_batch_size=self.cfg.trainer.micro_train_batch_size_per_gpu, drop_last=False | ||
| ) | ||
| The batch is split into micro batches based on micro_train_batch_size_per_gpu. | ||
| Megatron Core's forward_backward_func handles gradient accumulation internally. | ||
|
|
||
| micro_batches_per_mini_batch = ( | ||
| self.policy_mini_batch_size_per_gpu // self.cfg.trainer.micro_train_batch_size_per_gpu | ||
| ) | ||
| Args: | ||
| data: TrainingInputBatch (already DP-sharded by WorkerDispatch/MeshDispatch) | ||
|
|
||
| status_list = [] | ||
| Returns: | ||
| Aggregated metrics dict across all micro batches | ||
| """ | ||
| self.model.train() | ||
| for chunk in self.actor_module: | ||
| # if use distributed optimizer, zero grad buffer will be handled by optimizer | ||
| chunk.zero_grad_buffer() | ||
|
|
||
| micro_batch_size = self.cfg.trainer.micro_train_batch_size_per_gpu | ||
| all_metrics = defaultdict(list) | ||
| policy_update_steps = 0 | ||
|
|
||
| if self.profiler is not None: | ||
| self.profiler.start() | ||
| # Move data to GPU | ||
| data.to(torch.cuda.current_device()) | ||
|
|
||
| # Build micro-batch dicts expected by forward_backward_mini_batch | ||
| micro_buffer = [] | ||
| for experience in BatchIterator(data, micro_batch_size, drop_last=False): | ||
| sequences = experience.sequences | ||
| attention_mask = experience.attention_mask | ||
| position_ids = attention_mask.long().cumsum(-1) - 1 | ||
| position_ids.masked_fill_(attention_mask == 0, 0) | ||
|
|
||
| for epoch in range(self.cfg.trainer.update_epochs_per_batch): | ||
| self.optimizer.zero_grad() | ||
| pbar = tqdm( | ||
| dataloader, | ||
| desc=f"Policy Train epoch [{epoch + 1}/{self.cfg.trainer.update_epochs_per_batch}]", | ||
| disable=not self.strategy.is_rank_0(), | ||
| micro_buffer.append( | ||
| { | ||
| "sequences": sequences, | ||
| "attention_mask": attention_mask, | ||
| "position_ids": position_ids, | ||
| "num_actions": experience.num_actions, | ||
| "old_action_log_probs": experience.action_log_probs, | ||
| "base_action_log_probs": experience.base_action_log_probs, | ||
| "advantages": experience.advantages, | ||
| "loss_mask": experience.loss_mask, | ||
| "rollout_action_logprobs": experience.rollout_logprobs, | ||
| } | ||
| ) | ||
|
|
||
| # TODO: Convert this into 2 loops for minibatches and microbatches. | ||
| micro_buffer = [] | ||
| for local_step, experience in enumerate(pbar): | ||
| # BatchIterator now yields Experience objects directly | ||
| experience.to_device(torch.cuda.current_device()) | ||
| sequences = experience.sequences | ||
| attention_mask = experience.attention_mask | ||
| position_ids = attention_mask.long().cumsum(-1) - 1 | ||
| position_ids.masked_fill_(attention_mask == 0, 0) | ||
|
|
||
| micro_buffer.append( | ||
| { | ||
| "sequences": sequences, | ||
| "attention_mask": attention_mask, | ||
| "position_ids": position_ids, | ||
| "num_actions": experience.num_actions, | ||
| "old_action_log_probs": experience.action_log_probs, | ||
| "base_action_log_probs": experience.base_action_log_probs, | ||
| "advantages": experience.advantages, | ||
| "loss_mask": experience.loss_mask, | ||
| "rollout_action_logprobs": experience.rollout_logprobs, | ||
| } | ||
| ) | ||
| if not micro_buffer: | ||
| return {} | ||
|
|
||
| if len(micro_buffer) == micro_batches_per_mini_batch: | ||
| # run mini-batch forward-backward and then one optimizer step | ||
| self.model.train() | ||
| for chunk in self.actor_module: | ||
| # if use distributed optimizer, zero grad buffer will be handled by optimizer | ||
| chunk.zero_grad_buffer() | ||
| seq_len = micro_buffer[0]["sequences"].shape[1] | ||
| micro_bsz = micro_buffer[0]["sequences"].shape[0] | ||
|
|
||
| metrics_list = self.model.forward_backward_mini_batch( | ||
| micro_batches=micro_buffer, | ||
| seq_len=seq_len, | ||
| micro_batch_size=micro_bsz, | ||
| temperature=self.cfg.generator.sampling_params.temperature, | ||
| ) | ||
| seq_len = micro_buffer[0]["sequences"].shape[1] | ||
| micro_bsz = micro_buffer[0]["sequences"].shape[0] | ||
|
|
||
| if self.empty_cuda_cache: | ||
| torch.cuda.empty_cache() | ||
|
|
||
| grad_norm = self.strategy.optimizer_step(self.optimizer, self.model, self.scheduler, name="actor") | ||
|
|
||
| # within a DP group, metrics are already the same across all workers - we then just all reduce across | ||
| # the whole world size to get the metrics for the global micro batch | ||
| for i, metrics in enumerate(metrics_list): | ||
| status = { | ||
| "final_loss": metrics["final_loss"], | ||
| "policy_loss": metrics["policy_loss"], | ||
| "policy_lr": self.optimizer.param_groups[0]["lr"], | ||
| "ppo_clip_ratio": metrics["ppo_clip_ratio"], | ||
| "policy_entropy": metrics["policy_entropy"], | ||
| } | ||
| if self.cfg.trainer.algorithm.use_kl_loss: | ||
| status["policy_kl"] = metrics["policy_kl"] | ||
|
|
||
| # Attach grad norm only for the last micro in the mini-batch | ||
| if i == len(metrics_list) - 1 and grad_norm is not None: | ||
| status["raw_grad_norm"] = grad_norm | ||
|
|
||
| # attach response_length | ||
| status["response_length"] = micro_buffer[i]["num_actions"] | ||
|
|
||
| status = self.strategy.all_reduce(status) | ||
| status_list.append(status) | ||
| for k, v in status.items(): | ||
| all_metrics[k].append(v) | ||
|
|
||
| short_status = { | ||
| "pg": status_list[-1]["policy_loss"], | ||
| "glen": status_list[-1]["response_length"], | ||
| "policy_lr": status_list[-1]["policy_lr"], | ||
| "ent": status_list[-1]["policy_entropy"], | ||
| } | ||
| if "raw_grad_norm" in status_list[-1]: | ||
| short_status["grad_norm"] = status_list[-1]["raw_grad_norm"] | ||
| pbar.set_postfix(short_status) | ||
|
|
||
| policy_update_steps += 1 | ||
| micro_buffer = [] | ||
|
|
||
| # drop any trailing micros that don't fill a mini-batch (keep behavior consistent) | ||
| micro_buffer = [] | ||
| metrics_list = self.model.forward_backward_mini_batch( | ||
| micro_batches=micro_buffer, | ||
| seq_len=seq_len, | ||
| micro_batch_size=micro_bsz, | ||
| temperature=self.cfg.generator.sampling_params.temperature, | ||
| ) | ||
|
|
||
| torch.distributed.barrier() | ||
| if self.profiler is not None: | ||
| self.profiler.stop_and_save() | ||
| self.profiler.stop_trace() | ||
| if self.empty_cuda_cache: | ||
| torch.cuda.empty_cache() | ||
|
|
||
| # not needed beyond status logging | ||
| all_metrics.pop("response_length", None) | ||
| # Track number of micro-batches for metrics | ||
| self._micro_batches_accumulated += len(micro_buffer) | ||
|
|
||
| status_mean = reduce_metrics(all_metrics) | ||
| status_mean["policy_update_steps"] = policy_update_steps | ||
| # Aggregate metrics across micro-batches | ||
| for metrics in metrics_list: | ||
| for k, v in metrics.items(): | ||
| all_metrics[k].append(v) | ||
|
|
||
| output = TrainingOutputBatch() | ||
| output.metadata = {"train_status": status_mean} | ||
| return output | ||
| # Reduce and all-reduce metrics | ||
| status = reduce_metrics(dict(all_metrics)) | ||
| status["policy_lr"] = self.optimizer.param_groups[0]["lr"] | ||
| status = self.strategy.all_reduce(status) | ||
|
|
||
| return status | ||
|
|
||
| def optim_step(self) -> Optional[float]: | ||
| """ | ||
| Perform optimizer step. | ||
|
|
||
| Note: Unlike FSDP workers, Megatron doesn't need manual gradient scaling here | ||
| because Megatron Core's forward_backward_func handles loss scaling internally. | ||
|
|
||
| Returns: | ||
| The gradient norm (before scaling, after clipping), or None if unavailable. | ||
| """ | ||
|
Collaborator
Need the gradient scaling for grad accumulation here, I think? I looked into it briefly and it seemed doable.
Collaborator
@tyler-griggs given the offline discussion of moving to loss sums, maybe we don't need to do any scaling here anymore? We will need to figure out how Megatron's internal gradient accumulation works and whether we can enforce that it also only does sums.
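To make the trade-off discussed above concrete, here is a toy illustration (not SkyRL or Megatron code) of why scaling is only needed when each micro batch backpropagates a mean loss: with per-micro-batch means, the accumulated gradient must be divided by the number of micro batches before the optimizer step, while summed losses normalized once by a global count need no extra scaling.

```python
import torch

# Toy gradient-accumulation example; the model, data, and normalizer are made up.
model = torch.nn.Linear(4, 1)
micro_batches = [torch.randn(2, 4) for _ in range(4)]
num_micro = len(micro_batches)

# Variant A: each micro batch backprops a *mean* loss. The accumulated
# gradient is num_micro times too large, so it must be scaled down
# before the optimizer step.
model.zero_grad()
for x in micro_batches:
    loss = model(x).pow(2).mean()
    loss.backward()
for p in model.parameters():
    p.grad /= num_micro  # manual scaling required

# Variant B: each micro batch backprops a *sum* loss divided by a global
# normalizer (e.g. total samples or tokens in the mini batch). The
# accumulated gradient is already correct, so no extra scaling is needed.
global_normalizer = sum(x.shape[0] for x in micro_batches)
model.zero_grad()
for x in micro_batches:
    loss = model(x).pow(2).sum() / global_normalizer
    loss.backward()
```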
||
| grad_norm = self.strategy.optimizer_step(self.optimizer, self.model, self.scheduler, name="actor") | ||
|
|
||
| # Reset counter for next accumulation cycle | ||
| self._micro_batches_accumulated = 0 | ||
|
|
||
| if grad_norm is not None: | ||
| grad_norm = grad_norm.detach().cpu().item() if hasattr(grad_norm, "item") else grad_norm | ||
| return grad_norm | ||
|
|
||
| def get_lr(self) -> float: | ||
| """ | ||
| Get current learning rate from optimizer. | ||
|
|
||
| Override base class method because Megatron's OptimizerParamScheduler | ||
| doesn't have get_last_lr() like PyTorch schedulers. | ||
| """ | ||
| return self.optimizer.param_groups[0]["lr"] | ||
|
|
||
| async def broadcast_to_inference_engines(self, inference_engine_client): | ||
| use_prefix_cache = self.cfg.generator.enable_prefix_caching | ||
|
|
||
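A small standalone illustration of the position-id derivation used in forward_backward above, showing how padded positions are kept from advancing the position counter (the example tensors are made up):

```python
import torch

# Left-padded batch: 0 marks padding, 1 marks real tokens.
attention_mask = torch.tensor([
    [0, 0, 1, 1, 1],
    [1, 1, 1, 1, 1],
])

# Same derivation as in forward_backward: the cumulative count of real
# tokens minus one gives each real token its 0-based position; padded
# slots are then forced to position 0.
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 0)

print(position_ids)
# tensor([[0, 0, 0, 1, 2],
#         [0, 1, 2, 3, 4]])
```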
The temperature parameter is typically associated with sampling in generation tasks. While it might be used internally by MegatronModelWrapper.forward_backward_mini_batch for entropy calculation or similar, its presence in a training method's signature can be confusing. Consider adding a comment to clarify its specific role in the training forward/backward pass, or, if possible, refactor MegatronModelWrapper to only expose parameters relevant to training loss calculation in this context.
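For reference, a plausible reason a sampling temperature appears in the training forward pass is that log-probs and entropy need to be computed from the same temperature-scaled distribution that was used at generation time; otherwise the importance ratio compares mismatched distributions. The sketch below illustrates that computation with made-up tensor names; it is not the actual MegatronModelWrapper implementation.

```python
import torch
import torch.nn.functional as F

def token_logprobs_and_entropy(logits: torch.Tensor, tokens: torch.Tensor, temperature: float):
    """Compute per-token log-probs and entropy from temperature-scaled logits.

    logits: [batch, seq_len, vocab] raw model outputs
    tokens: [batch, seq_len] token ids whose log-probs we want
    """
    # Match the distribution used at sampling time by dividing logits
    # by the same temperature before normalizing.
    scaled = logits / temperature
    logprobs = F.log_softmax(scaled, dim=-1)

    # Log-prob of the actually generated tokens.
    token_logprobs = logprobs.gather(-1, tokens.unsqueeze(-1)).squeeze(-1)

    # Entropy of the temperature-scaled distribution.
    entropy = -(logprobs.exp() * logprobs).sum(dim=-1)
    return token_logprobs, entropy
```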