diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 826b51160..f89740011 100755 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -14,6 +14,7 @@ NVIDIA Model Optimizer Changelog (Linux) - Add support for parallel draft heads in Eagle speculative decoding. - Add support to enable custom emulated quantization backend. See :meth:`register_quant_backend `` for more details. See an example in ``tests/unit/torch/quantization/test_custom_backend.py``. - Add ``examples/llm_qad`` for QAD training with Megatron-LM. +- Add support for ``params`` constraint based automatic neural architecture search in Minitron pruning (``mcore_minitron``) as an alternative to manual pruning using ``export_config``. See `examples/pruning/README.md `_ for more details on its usage. **Deprecations** @@ -80,7 +81,7 @@ NVIDIA Model Optimizer Changelog (Linux) **Documentation** -- Add general guidelines for Minitron pruning and distillation. See `examples/pruning/README.md `_ for more details. +- Add general guidelines for Minitron pruning and distillation. See `pruning guidelines `_ for more details. - Added example for exporting QLoRA checkpoint for vLLM deployment. Refer to `examples/llm_qat/README.md `_ for more details 0.37 (2025-10-08) diff --git a/modelopt/torch/nas/plugins/megatron.py b/modelopt/torch/nas/plugins/megatron.py index a25012709..5814009d0 100644 --- a/modelopt/torch/nas/plugins/megatron.py +++ b/modelopt/torch/nas/plugins/megatron.py @@ -27,8 +27,6 @@ from megatron.core.models.gpt import GPTModel from megatron.core.parallel_state import ( get_data_parallel_group, - get_pipeline_model_parallel_group, - get_tensor_model_parallel_group, is_pipeline_first_stage, is_pipeline_last_stage, ) @@ -54,13 +52,8 @@ from modelopt.torch.opt.searcher import ConstraintsDict from modelopt.torch.trace import Symbol from modelopt.torch.utils import distributed as dist -from modelopt.torch.utils import ( - get_module_device, - make_divisible, - param_num_from_forward, - print_rank_0, - random, -) +from modelopt.torch.utils import make_divisible, print_rank_0, random +from modelopt.torch.utils.plugins import param_num_megatron from ..algorithms import ( MODULE_TYPE_TO_CONSTRAINTS_FUNC, @@ -1045,7 +1038,6 @@ def modify( *, hidden_size_divisor: int = 1, ffn_hidden_size_divisor: int = 1, - mamba_num_heads_divisor: int = 1, mamba_head_dim_divisor: int = 1, num_moe_experts_divisor: int = 1, ): @@ -1054,7 +1046,6 @@ def modify( Args: hidden_size_divisor: The divisor of the hidden_size. ffn_hidden_size_divisor: The divisor of the mlp ffn_hidden_size. - mamba_num_heads_divisor: The divisor of the mamba num_heads. mamba_head_dim_divisor: The divisor of the mamba head_dim. num_moe_experts_divisor: The divisor of the number of MoE experts. 
""" @@ -1065,7 +1056,6 @@ def modify( for layer in self.decoder.layers: layer.modify( ffn_hidden_size_divisor=ffn_hidden_size_divisor, - mamba_num_heads_divisor=mamba_num_heads_divisor, mamba_head_dim_divisor=mamba_head_dim_divisor, num_moe_experts_divisor=num_moe_experts_divisor, ) @@ -1142,11 +1132,7 @@ def constraint_eval_funcs(self) -> dict[str, ConstraintEvalFunc]: def _get_params(self, _: ConstraintsRes | None = None) -> float: """Get number of model parameters from forward pass.""" - params = param_num_from_forward(self.model, args=self.dummy_input, unit=1.0) - reduced_params = torch.Tensor([params]).to(device=get_module_device(self.model)) - torch.distributed.all_reduce(reduced_params, group=get_pipeline_model_parallel_group()) - torch.distributed.all_reduce(reduced_params, group=get_tensor_model_parallel_group()) - return reduced_params.item() + return param_num_megatron(self.model, from_forward=True, args=self.dummy_input) def _get_flops(self, _: ConstraintsRes | None = None) -> float: """Get inference FLOPs.""" diff --git a/modelopt/torch/opt/searcher.py b/modelopt/torch/opt/searcher.py index 5eb2e134e..342903abe 100644 --- a/modelopt/torch/opt/searcher.py +++ b/modelopt/torch/opt/searcher.py @@ -35,7 +35,7 @@ import torch.nn as nn from modelopt.torch.utils import distributed as dist -from modelopt.torch.utils import no_stdout, run_forward_loop +from modelopt.torch.utils import no_stdout, print_rank_0, run_forward_loop LimitsTuple = tuple[float, float] ConstraintsDict = dict[str, str | float | dict | None] @@ -212,6 +212,7 @@ def construct_forward_loop( return None def forward_loop_with_silence_check(m: nn.Module) -> None: + print_rank_0("Running forward loop...") with no_stdout() if silent else nullcontext(): if data_loader is not None: run_forward_loop( diff --git a/modelopt/torch/prune/__init__.py b/modelopt/torch/prune/__init__.py index aac5f7e87..847b22e9d 100644 --- a/modelopt/torch/prune/__init__.py +++ b/modelopt/torch/prune/__init__.py @@ -19,8 +19,6 @@ simplifies the overall workflow to accommodate for the simpler nature of pruning algorithms. """ -# nas is a required - so let's check if it's available -import modelopt.torch.nas from modelopt.torch.utils import import_plugin from . import fastnas, gradnas, plugins diff --git a/modelopt/torch/prune/plugins/mcore_minitron.py b/modelopt/torch/prune/plugins/mcore_minitron.py index db6769b7b..40d5d608b 100644 --- a/modelopt/torch/prune/plugins/mcore_minitron.py +++ b/modelopt/torch/prune/plugins/mcore_minitron.py @@ -24,9 +24,9 @@ Actual dynamic module implementations are at :mod:`modelopt.torch.nas.plugins.megatron`. 
""" -import copy from collections.abc import Callable from functools import partial +from itertools import product from typing import Any from warnings import warn @@ -43,6 +43,7 @@ reduce_from_tensor_model_parallel_region, ) from pydantic import create_model +from tqdm import tqdm from modelopt.torch.nas.conversion import NASModeRegistry from modelopt.torch.nas.plugins.megatron import ( @@ -57,7 +58,7 @@ _DynamicTransformerLayer, ) from modelopt.torch.nas.registry import DMRegistry -from modelopt.torch.nas.utils import get_subnet_config, sort_parameters +from modelopt.torch.nas.utils import get_subnet_config, sample, sample_and_reset, sort_parameters from modelopt.torch.opt.config import ModeloptBaseConfig, get_kwargs_for_create_model_with_rules from modelopt.torch.opt.conversion import ApplyModeError from modelopt.torch.opt.dynamic import DynamicModule, DynamicSpace @@ -71,6 +72,7 @@ from modelopt.torch.opt.utils import named_hparams from modelopt.torch.utils import distributed as dist from modelopt.torch.utils import get_module_device, print_rank_0 +from modelopt.torch.utils.plugins import param_num_megatron from ..pruning import PruneModeRegistry @@ -172,6 +174,7 @@ class MCoreMinitronSearcher(BaseSearcher): activations_per_rank: list[dict[str, torch.Tensor]] layer_scores: dict[int, torch.Tensor] + top_k_candidates_per_constraint: dict[float, list[tuple[dict, float]]] @property def default_search_config(self) -> SearchConfig: @@ -181,12 +184,20 @@ def default_search_config(self) -> SearchConfig: "max_iter_data_loader": 1024, "skip_sorting": False, "scores_path": None, + # Additional search config for parameter-based pruning + "max_width_pruning": 0.5, # Maximum fraction per width hyperparameter to prune + "max_depth_pruning": 0.25, # Maximum fraction per depth hyperparameter to prune + "top_k": 10, # Number of candidates to consider for score_func validation } @property def default_state_dict(self) -> SearchStateDict: """Return default state dict for importance scores and activations from forward loop.""" - return {"activations_per_rank": [], "layer_scores": {}} + return { + "activations_per_rank": [], + "layer_scores": {}, + "top_k_candidates_per_constraint": {}, + } def sanitize_search_config(self, config: SearchConfig | None) -> SearchConfig: """Sanitize the search config dict.""" @@ -200,32 +211,41 @@ def before_search(self) -> None: super().before_search() # Check that the constraint is valid - assert self.constraints.keys() == {"export_config"}, ( - "Only `export_config` constraint is supported for pruning!" - ) - - self.constraints["export_config"] = copy.deepcopy(self.constraints["export_config"]) - export_config = self.constraints["export_config"] - if "num_query_groups" in export_config: - warn("num_query_groups is no longer supported (since 0.41)! It will be ignored.") - if export_config["num_query_groups"] != self.model.config.num_query_groups: # type: ignore[index] - raise ValueError(f"num_query_groups must be {self.model.config.num_query_groups}!") - export_config.pop("num_query_groups") # type: ignore[union-attr] - assert isinstance(export_config, dict) # to keep mypy happy - assert export_config.keys() <= SUPPORTED_HPARAMS, ( - f"Only {SUPPORTED_HPARAMS} are supported for pruning! Received: {export_config.keys()}" - ) + assert len(self.constraints) == 1 and next(iter(self.constraints.keys())) in { + "export_config", + "params", + }, "Only `export_config` or `params` constraint is supported!" 
+ + if "export_config" in self.constraints: + export_config = self.constraints["export_config"] + assert isinstance(export_config, dict) # to keep mypy happy + if "num_query_groups" in export_config: + warn("num_query_groups is no longer supported (since 0.41)! It will be ignored.") + if export_config["num_query_groups"] != self.model.config.num_query_groups: + raise ValueError( + f"num_query_groups must be {self.model.config.num_query_groups}!" + ) + export_config.pop("num_query_groups") + assert export_config.keys() <= SUPPORTED_HPARAMS, ( + f"Only {SUPPORTED_HPARAMS} are supported for pruning! Received: {export_config=}" + ) - # Only sort the parameters that are to be pruned - # If a user only prunes depth, we should not sort width parameters - self.hps_to_sort = SUPPORTED_HPARAMS & export_config.keys() + # Only sort the parameters that are to be pruned + # If a user only prunes depth, we should not sort width parameters + self.hps_to_sort = set(export_config.keys()) + else: + assert isinstance(self.constraints["params"], float), "params must be a float!" + assert self.has_score, "score_func (e.g. MMLU) is required for parameter-based pruning!" + export_config = None + # Sort all parameters for parameter-based pruning + self.hps_to_sort = SUPPORTED_HPARAMS for n, hp in named_hparams(self.model, unique=True): hp_name = n.split(".")[-1] if hp.is_configurable: # Make sure configurable hparams are the ones with right names else implementation needs to be fixed! assert hp_name in SUPPORTED_HPARAMS, f"[ImplError] Invalid hparam {hp_name}!" - if hp_name in export_config: + if export_config is not None and hp_name in export_config: assert export_config[hp_name] in hp.choices, ( f"Invalid choice {export_config[hp_name]} for {n}! Available choices: {hp.choices}" ) @@ -243,10 +263,8 @@ def run_search(self) -> None: registry = ImportanceEstimatorRegistry(unwrapped_model) if self.layer_scores and self.activations_per_rank: # Available from checkpoint - print_rank_0("Loading activations and scores per rank from checkpoint...") registry.set_activations_and_layer_scores(self.activations_per_rank, self.layer_scores) elif not self.config["skip_sorting"]: - print_rank_0("Running forward loop...") assert self.forward_loop is not None is_training = self.model.training self.model.eval() @@ -265,8 +283,17 @@ def run_search(self) -> None: else: sort_parameters(self.model, self.hps_to_sort, verbose=True) + if "params" in self.constraints: + export_config = self.search_best_arch_by_params( + max_params=self.constraints["params"], # type: ignore[arg-type] + max_width_pruning=self.config["max_width_pruning"], + max_depth_pruning=self.config["max_depth_pruning"], + top_k=self.config["top_k"], + ) + else: + export_config = self.constraints["export_config"] + # Prune homogeneously - export_config = self.constraints["export_config"] assert isinstance(export_config, dict) # to keep mypy happy for n, hp in named_hparams(self.model, configurable=True): hp_name = n.split(".")[-1] @@ -281,20 +308,150 @@ def run_search(self) -> None: layers_to_drop = [layer for layer, _ in sorted_layers[num_layers_hp.active :]] # type: ignore[misc] drop_mcore_language_model_layers(self.model, layers_to_drop=layers_to_drop) + # Update model config with pruned architecture # kv_channels can be None so we need to save original from original hidden_size and num_attention_heads - model_cfg = self.model.config - orig_kv_channels = getattr(model_cfg, "kv_channels") + orig_kv_channels = self.model.config.kv_channels if orig_kv_channels is None: - 
orig_kv_channels = getattr(model_cfg, "hidden_size") // getattr( - model_cfg, "num_attention_heads" + orig_kv_channels = ( + self.model.config.hidden_size // self.model.config.num_attention_heads ) - setattr(model_cfg, "kv_channels", orig_kv_channels) - for n in SUPPORTED_HPARAMS: - if n in export_config: - setattr(model_cfg, n, export_config[n]) + self.model.config.kv_channels = orig_kv_channels + for hp_name, hp_value in export_config.items(): + setattr(self.model.config, hp_name, hp_value) registry.cleanup() + def search_best_arch_by_params( + self, + max_params: float, + max_width_pruning: float = 0.5, + max_depth_pruning: float = 0.25, + top_k: int = 10, + ) -> dict: + """Search for the best architecture based on the given parameters constraints. + + We perform a grid-search over the search space to find subnets (homogeneous) fitting the constraints. + Top-k candidates (sorted by param count) are then validated using the score_func (e.g. MMLU) + and the best subnet is returned. + + Args: + max_params: Maximum number of parameters for the pruned model. + max_width_pruning: Maximum fraction per width hyperparameter to prune (default: 0.5). + Only top (1 - max_width_pruning) choices will be considered. + max_depth_pruning: Maximum fraction per depth hyperparameter to prune (default: 0.25). + Only top (1 - max_depth_pruning) choices will be considered. + top_k: Number of candidates to consider for score_func validation. + + Returns: + export_config: Dictionary mapping hyperparameter names to their pruned values. + """ + print_rank_0( + f"\nSearching for the best pruned architecture under {max_params / 1e9:.2f}B params constraints" + ) + + # 1. Find available search space choices (across all PP ranks) + hp_choices = {} + for n, hp in named_hparams(self.model, configurable=True): + hp_name = n.split(".")[-1] + hp_choices[hp_name] = hp.choices + all_pp_search_spaces = [None] * get_pipeline_model_parallel_world_size() + torch.distributed.all_gather_object( + all_pp_search_spaces, hp_choices, group=get_pipeline_model_parallel_group() + ) + hp_choices = {k: v for d in all_pp_search_spaces for k, v in d.items()} # type: ignore[attr-defined] + + # 2. Perform grid-search over the search space to find subnets fitting the constraints + if max_params not in self.top_k_candidates_per_constraint: + search_space_configs = MCoreMinitronSearcher._generate_search_space_combos( + hp_choices, # type: ignore[arg-type] + max_width_pruning, + max_depth_pruning, + ) + sample(self.model, sample_func=max) # reset for sanity + selected: list[tuple[dict, float]] = [] + for config in tqdm( + search_space_configs, + desc=f"Finding top {top_k} candidates fitting the constraints...", + disable=not dist.is_master(), + ): + # Convert search space config to fnmatch pattern and sample function + # Use partial to bind each value at creation time (avoid late-binding closure issue) + sample_func = { + f"*.{k}": partial(lambda val, choices: val, v) for k, v in config.items() + } + with sample_and_reset(self.model, sample_func=sample_func): # type: ignore[arg-type] + candidate_params = param_num_megatron(self.model) + if candidate_params <= max_params: + selected.append((config, candidate_params)) + assert len(selected) > 0, "No subnets found fitting the constraints!" 
+ self.top_k_candidates_per_constraint[max_params] = sorted( + selected, key=lambda x: x[1], reverse=True + )[:top_k] + self.save_search_checkpoint(verbose=True) + else: + print_rank_0(f"Using top {top_k} candidates from checkpoint") + top_k_candidates = self.top_k_candidates_per_constraint[max_params] + + # 3. Validate top-k candidates using the score_func and return the best subnet + # TODO: update this + best = top_k_candidates[0][0] + + return best + + @staticmethod + def _generate_search_space_combos( + search_space: dict[str, list], + max_width_pruning: float = 0.5, + max_depth_pruning: float = 0.25, + ) -> list[dict[str, Any]]: + """Generate all possible combinations of hyperparameters from the search space. + + Args: + search_space: Dictionary mapping hyperparameter names to their possible sorted choices. + Example: {"hidden_size": [1024, 2048, 3072, 4096], "num_layers": [1, 2, ..., 31, 32]} + max_width_pruning: Maximum fraction of width hyperparameters to prune (default: 0.5). + Only top (1 - max_width_pruning) choices will be considered. + max_depth_pruning: Maximum fraction of depth hyperparameters to prune (default: 0.25). + Only top (1 - max_depth_pruning) choices will be considered. + + Returns: + List of configuration dictionaries, where each dictionary maps hyperparameter + names to their chosen values. Example: + [ + {"hidden_size": 1024, "num_layers": 1}, + {"hidden_size": 1024, "num_layers": 2}, + ... + {"hidden_size": 4096, "num_layers": 32}, + ] + """ + print_rank_0( + f"\nOnly considering atmost {(max_width_pruning * 100):.0f}% for width and " + f"{max_depth_pruning * 100:.0f}% for depth pruning hparams" + ) + + filtered_ss = { + k: sorted(v)[int((1 - max_depth_pruning) * len(v)) :] + if k == "num_layers" + else sorted(v)[int((1 - max_width_pruning) * len(v)) :] + for k, v in search_space.items() + } + + ss_size = 1 + for k, v in filtered_ss.items(): + print_rank_0(f"\tSearch space for {k}: {v}") + ss_size *= len(v) + print_rank_0(f"\tTotal search space in consideration: {ss_size}\n") + + hparam_names = list(filtered_ss.keys()) + hparam_choices_lists = [filtered_ss[name] for name in hparam_names] + + search_space_combos = [ + dict(zip(hparam_names, choices)) for choices in product(*hparam_choices_lists) + ] + assert len(search_space_combos) == ss_size + + return search_space_combos + MCoreMinitronConfig: type[ModeloptBaseConfig] = create_model( "MCoreMinitronConfig", @@ -302,17 +459,17 @@ def run_search(self) -> None: registry=DMRegistry, default_rules={ "megatron.core.models.gpt.GPTModel": { - "hidden_size_divisor": 64, - "ffn_hidden_size_divisor": 64, - "num_moe_experts_divisor": 1, + "hidden_size_divisor": 256, + "ffn_hidden_size_divisor": 256, + "num_moe_experts_divisor": 8, }, **( { "megatron.core.models.mamba.MambaModel": { - "hidden_size_divisor": 64, - "ffn_hidden_size_divisor": 64, - "mamba_head_dim_divisor": 4, - "num_moe_experts_divisor": 1, + "hidden_size_divisor": 256, + "ffn_hidden_size_divisor": 256, + "mamba_head_dim_divisor": 8, + "num_moe_experts_divisor": 8, } } if HAS_MAMBA @@ -325,9 +482,7 @@ def run_search(self) -> None: def get_mcore_minitron_config( - channel_divisor: int = 64, - mamba_head_dim_divisor: int = 4, - num_moe_experts_divisor: int = 1, + channel_divisor: int = 256, mamba_head_dim_divisor: int = 8, num_moe_experts_divisor: int = 8 ) -> ModeloptBaseConfig: """Get a MCoreMinitronConfig with the given channel divisor instead of default.""" config = MCoreMinitronConfig() @@ -562,6 +717,7 @@ def set_activations_and_layer_scores( 
             activations_per_rank: List of dicts from module name to activations. Should match PP size.
             layer_scores: Dict from layer_number (1-indexed) to score.
         """
+        print_rank_0("Loading activations and scores per rank from checkpoint...")
         rank = get_pipeline_model_parallel_rank()
         pp_size = get_pipeline_model_parallel_world_size()
         assert len(activations_per_rank) == pp_size, (
diff --git a/modelopt/torch/utils/network.py b/modelopt/torch/utils/network.py
index 1940295c3..e18c85c3b 100644
--- a/modelopt/torch/utils/network.py
+++ b/modelopt/torch/utils/network.py
@@ -142,7 +142,7 @@ def param_num_from_forward(
     Returns:
         The number of parameters from the model's forward pass in the given unit.
 
-    This can helpful for dynamic modules, where the state dict might contain extra parameters that
+    This can be helpful for MoE or dynamic modules, where the state dict might contain extra parameters that
     is not actively used in the model, e.g., because of a DynamicModule that is deactivated for the
     forward pass. We circumvent this issue by just counting parameters of modules that appear in a
     forward pass.
diff --git a/modelopt/torch/utils/plugins/__init__.py b/modelopt/torch/utils/plugins/__init__.py
index 517c59914..ac1053aa2 100644
--- a/modelopt/torch/utils/plugins/__init__.py
+++ b/modelopt/torch/utils/plugins/__init__.py
@@ -23,5 +23,8 @@
 with import_plugin("megatron_mmlu"):
     from .megatron_mmlu import *
 
+with import_plugin("megatron_model"):
+    from .megatron_model import *
+
 with import_plugin("megatron_preprocess_data"):
     from .megatron_preprocess_data import *
diff --git a/modelopt/torch/utils/plugins/megatron_model.py b/modelopt/torch/utils/plugins/megatron_model.py
new file mode 100644
index 000000000..5ea2a7236
--- /dev/null
+++ b/modelopt/torch/utils/plugins/megatron_model.py
@@ -0,0 +1,57 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""General utilities for Megatron models."""
+
+from typing import Any
+
+import torch
+from megatron.core.parallel_state import (
+    get_pipeline_model_parallel_group,
+    get_tensor_model_parallel_group,
+)
+from megatron.core.transformer.module import MegatronModule
+
+from ..network import param_num_from_forward
+
+__all__ = ["param_num_megatron"]
+
+
+def param_num_megatron(
+    model: MegatronModule, *, from_forward: bool = False, args: Any = None
+) -> float:
+    """Get the number of parameters in the model (reduced across TP and PP ranks).
+
+    Args:
+        model: The Megatron model.
+        from_forward: Whether to get the number of params from a forward pass instead of directly counting the params.
+            This can be helpful for MoE or dynamic modules, where the state dict might contain extra parameters that
+            are not actively used in the model, e.g., because of a DynamicModule that is deactivated for the
+            forward pass. We circumvent this issue by just counting parameters of modules that appear in a
+            forward pass.
+ args: The arguments to pass to the forward pass. Only used if from_forward is True. + + Returns: + The number of parameters in the model (reduced across TP and PP ranks). + """ + if from_forward: + assert args is not None, "args must be provided if from_forward is True" + params = int(param_num_from_forward(model, args, unit=1.0)) + else: + params = sum(p.numel() for p in model.parameters()) + reduced_params = torch.Tensor([params]).to(device=next(model.parameters()).device) + torch.distributed.all_reduce(reduced_params, group=get_pipeline_model_parallel_group()) + torch.distributed.all_reduce(reduced_params, group=get_tensor_model_parallel_group()) + return reduced_params.item() diff --git a/tests/_test_utils/torch/nas_prune/minitron_common.py b/tests/_test_utils/torch/nas_prune/minitron_common.py index 856edd38c..97b12a4ca 100644 --- a/tests/_test_utils/torch/nas_prune/minitron_common.py +++ b/tests/_test_utils/torch/nas_prune/minitron_common.py @@ -19,7 +19,16 @@ def prune_minitron(model, export_config, config, channel_divisor=64): return mtp.prune( model, - mode=[("mcore_minitron", mtp.mcore_minitron.get_mcore_minitron_config(channel_divisor))], + mode=[ + ( + "mcore_minitron", + mtp.mcore_minitron.get_mcore_minitron_config( + channel_divisor=channel_divisor, + mamba_head_dim_divisor=4, + num_moe_experts_divisor=1, + ), + ) + ], constraints={"export_config": export_config}, dummy_input=None, # Not used config=config, diff --git a/tests/gpu/torch/nas/plugins/test_megatron_gpt_dynamic_modules.py b/tests/gpu/torch/nas/plugins/test_megatron_gpt_dynamic_modules.py index 2679d3090..16b45cdb0 100644 --- a/tests/gpu/torch/nas/plugins/test_megatron_gpt_dynamic_modules.py +++ b/tests/gpu/torch/nas/plugins/test_megatron_gpt_dynamic_modules.py @@ -25,6 +25,7 @@ from _test_utils.torch.megatron.models import get_mcore_gpt_model from _test_utils.torch.megatron.utils import run_mcore_inference from megatron.core.models.common.embeddings.language_model_embedding import LanguageModelEmbedding +from megatron.core.parallel_state import destroy_model_parallel from megatron.core.transformer.attention import SelfAttention from megatron.core.transformer.mlp import MLP from megatron.core.transformer.transformer_layer import TransformerLayer @@ -32,6 +33,7 @@ import modelopt.torch.nas as mtn from modelopt.torch.nas.modules import DynamicModuleList from modelopt.torch.nas.plugins.megatron import ( + NumAttentionHeadsHp, _DynamicColumnParallelLinear, _DynamicEmbedding, _DynamicLanguageModelEmbedding, @@ -81,7 +83,7 @@ def _test_gpt_search_space( normalization=normalization, ).cuda() - model = mtn.convert(model, [("mcore_minitron", get_mcore_minitron_config(channel_divisor))]) + mtn.convert(model, [("mcore_minitron", get_mcore_minitron_config(channel_divisor))]) assert isinstance(model, _DynamicMCoreLanguageModel) for m in model.modules(): @@ -153,6 +155,74 @@ def test_expand_head_indices(): assert expand_head_indices(heads, hidden_size_per_head).tolist() == [2, 3, 6, 7, 4, 5, 0, 1] +def test_gpt_self_attention_head_sorting(distributed_setup_size_1): + model = get_mcore_gpt_model( + tensor_model_parallel_size=1, + pipeline_model_parallel_size=1, + initialize_megatron=True, + num_layers=1, + hidden_size=16, + num_attention_heads=8, + num_query_groups=2, + ffn_hidden_size=16, + activation_func="squared_relu", + ).cuda() + + model = mtn.convert(model, "mcore_minitron") + + self_attn = model.decoder.layers[0].self_attention + assert isinstance(self_attn, _DynamicSelfAttention) + assert 
isinstance(self_attn.linear_qkv, _DynamicQKVColumnParallelLinear) + assert isinstance(self_attn.linear_proj, _DynamicProjRowParallelLinear) + + hp_num_attention_heads = self_attn.get_hparam("num_attention_heads") + assert isinstance(hp_num_attention_heads, NumAttentionHeadsHp) + + # Choices are multiples of num_query_groups (2): [2, 4, 6, 8] + assert hp_num_attention_heads.choices == [2, 4, 6, 8] + assert hp_num_attention_heads._num_query_groups == 2 + + # Set importance and slice order + # Importance per head (group-aware): [2.2, 0.1, 1.1, 2.1, 3.0, 2.0, 0.0, 1.0] + # Group 0 (heads 0-3): [2.2, 0.1, 1.1, 2.1] → sorted: [0, 3, 2, 1] + # Group 1 (heads 4-7): [3.0, 2.0, 0.0, 1.0] → sorted: [4, 5, 7, 6] + # Global ranking (group-aware, flattened): [0, 3, 2, 1, 4, 5, 7, 6] + hp_num_attention_heads._get_importance = lambda: torch.tensor( + [2.2, 0.1, 1.1, 2.1, 3.0, 2.0, 0.0, 1.0] + ) + # _estimate_head_ranking returns ranking as 1D tensor + expected_ranking = torch.tensor([0, 3, 2, 1, 4, 5, 7, 6]) + hp_num_attention_heads.enforce_order(expected_ranking) + + assert hp_num_attention_heads.active_slice.tolist() == [0, 3, 2, 1, 4, 5, 7, 6] + + # check if we get correct selection of sorted + pruned heads after setting active values + hp_num_attention_heads.active = 4 # top 2 heads per group (2 groups * 2 heads = 4 total) + + # Expected: Top 2 heads from each group: [0, 3] from group 0, [4, 5] from group 1 + expected_q_heads = [0, 3, 4, 5] + # In QKV layout (4 heads/group → 6 QKV heads/group): + # Group 0: Q=[0, 3], K=4, V=5 → QKV indices [0, 3, 4, 5] + # Group 1: Q=[4, 5], K=10, V=11 → QKV indices [6, 7, 10, 11] + expected_qkv_heads = [0, 3, 4, 5, 6, 7, 10, 11] + + assert ( + self_attn.linear_qkv._get_output_size_indices().tolist() + == expand_head_indices( + torch.LongTensor(expected_qkv_heads), model.config.kv_channels + ).tolist() + ) + assert ( + self_attn.linear_proj._get_input_size_indices().tolist() + == expand_head_indices( + torch.LongTensor(expected_q_heads), model.config.kv_channels + ).tolist() + ) + + # Clean up since this is not a spawned process + destroy_model_parallel() + + def _test_gpt_moe_search_space(rank, size): channel_divisor = 4 @@ -183,7 +253,10 @@ def _test_gpt_moe_search_space(rank, size): moe_shared_expert_intermediate_size=moe_shared_expert_intermediate_size, ).cuda() - model = mtn.convert(model, [("mcore_minitron", get_mcore_minitron_config(channel_divisor))]) + mtn.convert( + model, + [("mcore_minitron", get_mcore_minitron_config(channel_divisor, num_moe_experts_divisor=1))], + ) moe = model.decoder.layers[0].mlp assert isinstance(moe, _DynamicMoELayer) diff --git a/tests/gpu/torch/nas/plugins/test_megatron_mamba_dynamic_modules.py b/tests/gpu/torch/nas/plugins/test_megatron_mamba_dynamic_modules.py index 430b5e261..6a1bc7a8a 100644 --- a/tests/gpu/torch/nas/plugins/test_megatron_mamba_dynamic_modules.py +++ b/tests/gpu/torch/nas/plugins/test_megatron_mamba_dynamic_modules.py @@ -51,7 +51,7 @@ def _test_mamba_search_space(rank, size): mamba_head_dim_divisor = 4 num_layers = size - hybrid_override_pattern = "M" * size + hybrid_override_pattern = "M" * size # all layers are Mamba layers hidden_size = channel_divisor * 4 mamba_state_dim = channel_divisor mamba_head_dim = mamba_head_dim_divisor * 2 @@ -75,7 +75,10 @@ def _test_mamba_search_space(rank, size): ).cuda() mamba_num_heads = model.decoder.layers[0].mixer.nheads - model = mtn.convert(model, [("mcore_minitron", get_mcore_minitron_config(channel_divisor))]) + mtn.convert( + model, + [("mcore_minitron", 
get_mcore_minitron_config(channel_divisor, mamba_head_dim_divisor))], + ) assert isinstance(model, _DynamicMCoreLanguageModel) if is_pipeline_first_stage(): diff --git a/tests/gpu/torch/prune/plugins/test_mcore_gpt_minitron_pruning.py b/tests/gpu/torch/prune/plugins/test_mcore_gpt_minitron_pruning.py index a0d4877bb..2f1eae76b 100644 --- a/tests/gpu/torch/prune/plugins/test_mcore_gpt_minitron_pruning.py +++ b/tests/gpu/torch/prune/plugins/test_mcore_gpt_minitron_pruning.py @@ -29,20 +29,13 @@ ) from _test_utils.torch.misc import compare_outputs, set_seed from _test_utils.torch.nas_prune.minitron_common import prune_minitron -from megatron.core.parallel_state import destroy_model_parallel from megatron.core.transformer.identity_op import IdentityOp import modelopt.torch.nas as mtn from modelopt.torch.nas.conversion import export_searchspace -from modelopt.torch.nas.plugins.megatron import ( - NumAttentionHeadsHp, - _DynamicProjRowParallelLinear, - _DynamicQKVColumnParallelLinear, - _DynamicSelfAttention, - expand_head_indices, -) from modelopt.torch.prune.plugins.mcore_minitron import ( ImportanceEstimatorRegistry, + MCoreMinitronSearcher, _convert_model_to_dynamic_space, get_mcore_minitron_config, ) @@ -124,74 +117,6 @@ def test_mcore_gpt_parameter_sorting(activation_func): ) -def test_mcore_gpt_self_attention_head_sorting(distributed_setup_size_1): - model = get_mcore_gpt_model( - tensor_model_parallel_size=1, - pipeline_model_parallel_size=1, - initialize_megatron=True, - num_layers=1, - hidden_size=16, - num_attention_heads=8, - num_query_groups=2, - ffn_hidden_size=16, - activation_func="squared_relu", - ).cuda() - - model = mtn.convert(model, "mcore_minitron") - - self_attn = model.decoder.layers[0].self_attention - assert isinstance(self_attn, _DynamicSelfAttention) - assert isinstance(self_attn.linear_qkv, _DynamicQKVColumnParallelLinear) - assert isinstance(self_attn.linear_proj, _DynamicProjRowParallelLinear) - - hp_num_attention_heads = self_attn.get_hparam("num_attention_heads") - assert isinstance(hp_num_attention_heads, NumAttentionHeadsHp) - - # Choices are multiples of num_query_groups (2): [2, 4, 6, 8] - assert hp_num_attention_heads.choices == [2, 4, 6, 8] - assert hp_num_attention_heads._num_query_groups == 2 - - # Set importance and slice order - # Importance per head (group-aware): [2.2, 0.1, 1.1, 2.1, 3.0, 2.0, 0.0, 1.0] - # Group 0 (heads 0-3): [2.2, 0.1, 1.1, 2.1] → sorted: [0, 3, 2, 1] - # Group 1 (heads 4-7): [3.0, 2.0, 0.0, 1.0] → sorted: [4, 5, 7, 6] - # Global ranking (group-aware, flattened): [0, 3, 2, 1, 4, 5, 7, 6] - hp_num_attention_heads._get_importance = lambda: torch.tensor( - [2.2, 0.1, 1.1, 2.1, 3.0, 2.0, 0.0, 1.0] - ) - # _estimate_head_ranking returns ranking as 1D tensor - expected_ranking = torch.tensor([0, 3, 2, 1, 4, 5, 7, 6]) - hp_num_attention_heads.enforce_order(expected_ranking) - - assert hp_num_attention_heads.active_slice.tolist() == [0, 3, 2, 1, 4, 5, 7, 6] - - # check if we get correct selection of sorted + pruned heads after setting active values - hp_num_attention_heads.active = 4 # top 2 heads per group (2 groups * 2 heads = 4 total) - - # Expected: Top 2 heads from each group: [0, 3] from group 0, [4, 5] from group 1 - expected_q_heads = [0, 3, 4, 5] - # In QKV layout (4 heads/group → 6 QKV heads/group): - # Group 0: Q=[0, 3], K=4, V=5 → QKV indices [0, 3, 4, 5] - # Group 1: Q=[4, 5], K=10, V=11 → QKV indices [6, 7, 10, 11] - expected_qkv_heads = [0, 3, 4, 5, 6, 7, 10, 11] - - assert ( - 
self_attn.linear_qkv._get_output_size_indices().tolist() - == expand_head_indices( - torch.LongTensor(expected_qkv_heads), model.config.kv_channels - ).tolist() - ) - assert ( - self_attn.linear_proj._get_input_size_indices().tolist() - == expand_head_indices( - torch.LongTensor(expected_q_heads), model.config.kv_channels - ).tolist() - ) - - # Clean up since this is not a spawned process - destroy_model_parallel() - - def _test_mcore_gpt_pruning( num_attention_heads, num_query_groups, @@ -430,7 +355,7 @@ def _test_mcore_gpt_moe_parameter_sorting(rank, size): model.eval() dynamic_space = _convert_model_to_dynamic_space( - model, get_mcore_minitron_config(channel_divisor) + model, get_mcore_minitron_config(channel_divisor=channel_divisor, num_moe_experts_divisor=1) ) registry = ImportanceEstimatorRegistry(model) # register imp estimators and forward hooks @@ -570,3 +495,29 @@ def test_mcore_gpt_pruning_moe(tmp_path): job=partial(_test_mcore_gpt_pruning_moe, tmp_path / "minitron_scores.pth"), backend="nccl", ) + + +def test_generate_search_space_combos(): + ss = { + "hidden_size": [32, 64, 96, 128, 160], + "num_attention_heads": [8, 16, 24, 32], + "num_layers": [1, 2, 3, 4, 5, 6, 7, 8], + } + ss_combos = MCoreMinitronSearcher._generate_search_space_combos( + ss, max_width_pruning=0.5, max_depth_pruning=0.25 + ) + assert len(ss_combos) == 3 * 2 * 2 + assert ss_combos == [ + {"hidden_size": 96, "num_attention_heads": 24, "num_layers": 7}, + {"hidden_size": 96, "num_attention_heads": 24, "num_layers": 8}, + {"hidden_size": 96, "num_attention_heads": 32, "num_layers": 7}, + {"hidden_size": 96, "num_attention_heads": 32, "num_layers": 8}, + {"hidden_size": 128, "num_attention_heads": 24, "num_layers": 7}, + {"hidden_size": 128, "num_attention_heads": 24, "num_layers": 8}, + {"hidden_size": 128, "num_attention_heads": 32, "num_layers": 7}, + {"hidden_size": 128, "num_attention_heads": 32, "num_layers": 8}, + {"hidden_size": 160, "num_attention_heads": 24, "num_layers": 7}, + {"hidden_size": 160, "num_attention_heads": 24, "num_layers": 8}, + {"hidden_size": 160, "num_attention_heads": 32, "num_layers": 7}, + {"hidden_size": 160, "num_attention_heads": 32, "num_layers": 8}, + ] diff --git a/tests/gpu/torch/prune/plugins/test_mcore_mamba_minitron_pruning.py b/tests/gpu/torch/prune/plugins/test_mcore_mamba_minitron_pruning.py index d6fa9400b..a7f036bbb 100644 --- a/tests/gpu/torch/prune/plugins/test_mcore_mamba_minitron_pruning.py +++ b/tests/gpu/torch/prune/plugins/test_mcore_mamba_minitron_pruning.py @@ -78,7 +78,7 @@ def _test_mcore_mamba_parameter_sorting(rank, size): model.eval() dynamic_space = _convert_model_to_dynamic_space( - model, get_mcore_minitron_config(channel_divisor) + model, get_mcore_minitron_config(channel_divisor=channel_divisor, mamba_head_dim_divisor=4) ) registry = ImportanceEstimatorRegistry(model) # register imp estimators and forward hooks
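
The new ``params`` constraint in ``mcore_minitron`` can be exercised roughly as follows. This is a minimal sketch against the API added in this diff; ``build_megatron_model``, ``calibration_forward_loop``, and ``mmlu_score`` are hypothetical placeholders (any callables with these shapes work) and are not part of the change.

import modelopt.torch.prune as mtp

model = build_megatron_model()  # hypothetical: an MCore GPTModel/MambaModel already set up on GPU

def calibration_forward_loop(m):
    ...  # hypothetical: run calibration batches through m for activation-based importance estimation

def mmlu_score(m) -> float:
    ...  # hypothetical scoring callable (e.g. MMLU); required for params-based pruning

# Prune to a ~6B-parameter budget instead of passing a manual export_config.
pruned_model, _ = mtp.prune(
    model,
    mode="mcore_minitron",
    constraints={"params": 6.0e9},  # must be a float: max total params (reduced across TP/PP)
    dummy_input=None,  # not used by mcore_minitron
    config={
        "forward_loop": calibration_forward_loop,
        "score_func": mmlu_score,    # validates the top-k grid-search candidates
        "max_width_pruning": 0.5,    # optional; these three match default_search_config
        "max_depth_pruning": 0.25,
        "top_k": 10,
    },
)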