diff --git a/modelopt/onnx/quantization/autotune/common.py b/modelopt/onnx/quantization/autotune/common.py new file mode 100644 index 000000000..0c1cc8537 --- /dev/null +++ b/modelopt/onnx/quantization/autotune/common.py @@ -0,0 +1,688 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Common data structures and types for the QDQ Autotuner. + +This module provides the foundational classes used throughout the autotuner: + +**Exceptions:** +- Region-related: RegionError +- Autotuner-related: AutotunerError, AutotunerNotInitializedError, InvalidSchemeError + +**Region Hierarchy:** +- Region: Hierarchical subgraph representation with parent/child relationships +- RegionType: Enumeration for LEAF, COMPOSITE, and ROOT regions + +**Q/DQ Insertion Specifications:** +- InsertionScheme: Collection of insertion points with performance metrics + +**Scheme Management:** +- PatternSchemes: Multiple insertion schemes for a pattern (applies to all matching regions) +- PatternCache: Collection of top schemes for multiple patterns, used as autotuning seeds + +**Configuration:** +- Config: Autotuning parameters and Q/DQ default values +""" + +import hashlib +import logging +from dataclasses import dataclass, field +from enum import Enum +from typing import Any, Optional + +import onnx_graphsurgeon as gs + +from modelopt.onnx.quantization.autotune.insertion_points import ( + ChildRegionInputInsertionPoint, + NodeInputInsertionPoint, + RegionOutputInsertionPoint, +) + +# Module logger +logger = logging.getLogger(__name__) + + +# Region-related Exceptions +class RegionError(Exception): + """Base exception for region-related errors.""" + + +# Autotuner-related Exceptions +class AutotunerError(Exception): + """Base exception for autotuner-related errors.""" + + +class AutotunerNotInitializedError(AutotunerError): + """Exception raised when autotuner is used without initialization.""" + + +class InvalidSchemeError(AutotunerError): + """Exception raised when an invalid scheme is referenced.""" + + +class RegionType(Enum): + """Region type enumeration for hierarchical graph structure. + + - LEAF: Atomic region containing direct nodes with no child regions + - COMPOSITE: Hierarchical region containing child regions (and optionally direct nodes) + - ROOT: Top-level region encompassing the entire computation graph + """ + + LEAF = "LEAF" + COMPOSITE = "COMPOSITE" + ROOT = "ROOT" + + +class Region: + """Hierarchical subgraph region in an ONNX computation graph. + + A Region represents a cohesive subgraph with well-defined boundaries, supporting: + + **Hierarchical Structure:** + - Parent/child relationships forming a multi-level hierarchy + - LEAF regions contain only direct nodes + - COMPOSITE regions contain child regions (and optionally direct nodes) + - ROOT regions encompass the entire graph + + **Node Management:** + - Direct nodes: Operations directly in this region (not in children) + - Recursive nodes: All operations including those in descendant regions + + **Boundary Tracking:** + - Input tensors: Data entering the region from outside + - Output tensors: Data leaving the region to outside consumers + + **Pattern Matching:** + - Regions with identical structure share the same pattern signature + - Pattern-based optimization applies schemes to all matching regions + + Regions are the fundamental unit for Q/DQ insertion and optimization. + """ + + def __init__(self, region_id: int, level: int, region_type: RegionType): + """Initialize a new region. + + Args: + region_id: Unique identifier within the region hierarchy + level: Hierarchical level (0 = leaf, higher = more composite) + region_type: Type classification (LEAF, COMPOSITE, or ROOT) + """ + self.id = region_id + self.level = level + self.type = region_type + self.parent: Region | None = None + self.children: list[Region] = [] + self.nodes: set[int] = set() + self.inputs: list[str] = [] + self.outputs: list[str] = [] + + # ========================================================================= + # Basic Accessors + # ========================================================================= + + def get_id(self) -> int: + """Get region ID.""" + return self.id + + def set_id(self, region_id: int) -> None: + """Set region ID (for RegionBuilder use).""" + self.id = region_id + + def get_level(self) -> int: + """Get region level in hierarchy.""" + return self.level + + def set_level(self, level: int) -> None: + """Set region level in hierarchy (for RegionBuilder use).""" + self.level = level + + def get_type(self) -> RegionType: + """Get region type.""" + return self.type + + def set_type(self, region_type: RegionType) -> None: + """Set region type (for RegionBuilder use).""" + self.type = region_type + + # ========================================================================= + # Hierarchy Management + # ========================================================================= + + def get_parent(self) -> Optional["Region"]: + """Get parent region.""" + return self.parent + + def set_parent(self, parent: Optional["Region"]) -> None: + """Set parent region.""" + self.parent = parent + + def get_children(self) -> list["Region"]: + """Get all child regions.""" + return self.children + + def remove_child(self, child: "Region") -> bool: + """Remove a child region from this region's children list. + + Args: + child: The child region to remove + + Returns: + True if child was found and removed, False otherwise + """ + child_id = child.get_id() + initial_count = len(self.children) + self.children = [c for c in self.children if c.get_id() != child_id] + removed = len(self.children) < initial_count + + if removed and child.parent and child.parent.get_id() == self.id: + child.set_parent(None) + + return removed + + def add_child(self, child: "Region") -> None: + """Add a child sub-region.""" + # Prevent adding self as child + if child.get_id() == self.id: + logger.warning(f"Cannot add region {self.id} as its own child") + return + + # Prevent creating cycles: check if self is already a descendant of child + if self._is_descendant_of(child): + logger.warning( + f"Cycle detected: region {self.id} is already a descendant of region {child.get_id()}" + ) + return + + # Check if child already has a different parent + if child.parent is not None and child.parent.get_id() != self.id: + old_parent_id = child.parent.get_id() + logger.debug( + f"Re-parenting region {child.get_id()}: moving from parent {old_parent_id} to {self.id}" + ) + # Remove from old parent to maintain tree structure + child.parent.remove_child(child) + + # Check if child is already in children list + if any(c.get_id() == child.get_id() for c in self.children): + logger.debug(f"Region {child.get_id()} already child of {self.id}") + return + + self.children.append(child) + child.set_parent(self) + + def _is_descendant_of(self, potential_ancestor: "Region") -> bool: + """Check if this region is a descendant of potential_ancestor.""" + visited = set() + current = self.parent + while current: + if current.get_id() in visited: + # Already visited, there's a cycle in parents + return False + visited.add(current.get_id()) + if current.get_id() == potential_ancestor.get_id(): + return True + current = current.parent + return False + + # ========================================================================= + # Node Management + # ========================================================================= + + def add_node(self, node_index: int) -> None: + """Add a node index to this region.""" + self.nodes.add(node_index) + + def add_nodes(self, node_indices: list[int]) -> None: + """Add multiple node indices to this region.""" + self.nodes.update(node_indices) + + def get_nodes(self) -> set[int]: + """Get direct node indices in this region only. + + Returns only nodes directly owned by this region, excluding nodes + in child regions. Use get_all_nodes_recursive() for complete coverage. + + Returns: + Set of node indices (absolute positions in the graph) + """ + return self.nodes + + def get_all_nodes_recursive(self, _visited: set[int] | None = None) -> set[int]: + """Get all node indices recursively, including descendants. + + Traverses the entire subtree rooted at this region, collecting nodes + from this region and all child regions recursively. + + Args: + _visited: Internal parameter for cycle detection (do not use) + + Returns: + Set of all node indices in this region and its descendants + """ + if _visited is None: + _visited = set() + + # Detect cycles + if self.id in _visited: + logger.warning(f"Cycle detected in region {self.id} during node traversal") + return set() + + _visited.add(self.id) + all_nodes = set(self.nodes) + for child in self.children: + all_nodes.update(child.get_all_nodes_recursive(_visited)) + return all_nodes + + def contains_node(self, node_index: int) -> bool: + """Check if region contains a specific node (direct only).""" + return node_index in self.nodes + + def contains_node_recursive(self, node_index: int, _visited: set[int] | None = None) -> bool: + """Check if region contains a node recursively.""" + if _visited is None: + _visited = set() + + # Detect cycles + if self.id in _visited: + return False + + _visited.add(self.id) + + if self.contains_node(node_index): + return True + return any(child.contains_node_recursive(node_index, _visited) for child in self.children) + + # ========================================================================= + # Input/Output Management + # ========================================================================= + + def add_input(self, tensor_name: str) -> None: + """Add an input tensor name.""" + if tensor_name not in self.inputs: + self.inputs.append(tensor_name) + + def add_output(self, tensor_name: str) -> None: + """Add an output tensor name.""" + if tensor_name not in self.outputs: + self.outputs.append(tensor_name) + + def get_inputs(self) -> list[str]: + """Get region input tensors.""" + return self.inputs + + def get_outputs(self) -> list[str]: + """Get region output tensors.""" + return self.outputs + + # ========================================================================= + # Size and Query Methods + # ========================================================================= + + def get_size(self) -> int: + """Get the number of direct nodes in this region. + + Returns: + Count of nodes directly in this region (excludes child regions) + """ + return len(self.nodes) + + def get_total_size(self, _visited: set[int] | None = None) -> int: + """Get total node count recursively including all descendants. + + Computes the sum of nodes in this region and all child regions, + providing the total footprint of the region subtree. + + Args: + _visited: Internal parameter for cycle detection (do not use) + + Returns: + Total number of nodes in this region and all descendants + """ + if _visited is None: + _visited = set() + + # Detect cycles + if self.id in _visited: + logger.warning(f"Cycle detected in region {self.id} during size calculation") + return len(self.nodes) + + _visited.add(self.id) + total = len(self.nodes) + for child in self.children: + total += child.get_total_size(_visited) + return total + + # ========================================================================= + # Region Operations + # ========================================================================= + + def merge(self, other: "Region") -> None: + """Merge another region into this one. + + Combines the nodes and children from the other region into this region. + The other region's children become children of this region, updating + their parent references accordingly. + + Args: + other: Region to merge into this one + """ + if not other: + return + # Merge direct nodes + self.nodes.update(other.nodes) + # Merge children (updates their parent references) + for child in other.children: + self.add_child(child) + + # ========================================================================= + # String Representation + # ========================================================================= + + def to_string(self) -> str: + """Print region information for debugging.""" + type_str = self.type.value + return ( + f"Region[id={self.id}, level={self.level}, type={type_str}, " + f"nodes={len(self.nodes)}, children={len(self.children)}, " + f"inputs={len(self.inputs)}, outputs={len(self.outputs)}]" + ) + + def __str__(self) -> str: + return self.to_string() + + def __repr__(self) -> str: + return self.to_string() + + def compute_structural_signature(self, graph: gs.Graph) -> str: + """Compute deterministic structural signature for pattern matching. + + Creates a signature that uniquely identifies the region's topology, + node operations, and hierarchical structure. Regions with identical + signatures can share Q/DQ insertion schemes. + + The signature captures: + - Node operation types and key parameters + - Hierarchical structure (child regions) + - Deterministic ordering (sorted for consistency) + + Args: + graph: The ONNX graph containing the region's nodes + + Returns: + Signature string (e.g., "Conv->BatchNorm->Relu" or "COMPOSITE(...)") + """ + raise NotImplementedError("Not implemented") + + +# ============================================================================= +# Autotuner Q/DQ Insertion Specifications +# ============================================================================= + + +@dataclass +class InsertionScheme: + """Complete Q/DQ insertion specification for a region pattern. + + An InsertionScheme defines a complete Q/DQ configuration for a pattern, + combining both node-level and region-level insertion points. The scheme + is applied to all regions matching the pattern. + + **Scheme Identity:** + - Uniquely identified by the combination of insertion points (computed hash) + - latency_ms is a measured performance metric, not part of identity + - Two schemes with same insertion points but different latencies are considered identical + + **Application:** + - Node insertion points: Q/DQ at node inputs within the pattern + - Region insertion points: Q/DQ at child region boundaries (COMPOSITE only) + - All are resolved to actual configurations for each matching region + + **Performance Tracking:** + - latency_ms: Measured performance (inf = not yet measured) + - error: Whether this scheme encountered an error during measurement + - Used to select the best scheme for each pattern + + **Attributes:** + node_inputs: Q/DQ insertions at node inputs (list of NodeInputInsertionPoint) + child_region_inputs: Q/DQ insertions at child boundaries (list of ChildRegionInputInsertionPoint) + region_outputs: Q/DQ insertions at region outputs (list of RegionOutputInsertionPoint) + latency_ms: Measured latency in milliseconds (inf if not measured) + error: True if scheme measurement failed, False otherwise + profile_timestamp: ISO format timestamp when this scheme was profiled (None if not yet profiled) + """ + + node_inputs: list[NodeInputInsertionPoint] = field(default_factory=list) + child_region_inputs: list[ChildRegionInputInsertionPoint] = field(default_factory=list) + region_outputs: list[RegionOutputInsertionPoint] = field(default_factory=list) + latency_ms: float = float("inf") + error: bool = False + profile_timestamp: str | None = None + + @property + def hash(self) -> str: + """Compute deterministic hash for scheme identity. + + The hash uniquely identifies this scheme configuration based on its + insertion points. Two schemes with identical insertion points produce + the same hash, regardless of their measured latencies. + + **Hash Input:** + - Sorted node_inputs (for deterministic ordering) + - Sorted child_region_inputs (for deterministic ordering) + - Sorted region_outputs (for deterministic ordering) + - latency_ms is EXCLUDED (performance metric, not identity) + + **Use Cases:** + - Detect duplicate schemes before measurement + - Group schemes by configuration + - Efficient scheme comparison + + Returns: + 32-character hexadecimal string (SHA-256 truncated to 128 bits) + """ + # Sort points for deterministic hashing + sorted_nodes = sorted([(pt.node_index, pt.input_index) for pt in self.node_inputs]) + sorted_regions = sorted( + [(pt.region_index, pt.input_index) for pt in self.child_region_inputs] + ) + sorted_region_outputs = sorted( + [(pt.region_index, pt.node_index, pt.output_index) for pt in self.region_outputs] + ) + + # Create hash input string + hash_input = f"{sorted_nodes}|{sorted_regions}|{sorted_region_outputs}" + + # Compute SHA-256 hash (128 bits) + return hashlib.sha256(hash_input.encode("utf-8")).hexdigest()[:32] + + @property + def is_empty(self) -> bool: + """Check if this is a baseline scheme with no Q/DQ insertions. + + Returns: + True if scheme has no node/region insertion points + """ + return ( + len(self.node_inputs) == 0 + and len(self.child_region_inputs) == 0 + and len(self.region_outputs) == 0 + ) + + @property + def has_error(self) -> bool: + """Check if this scheme encountered an error during measurement. + + Returns: + True if scheme has error=True, False otherwise + """ + return self.error + + @property + def is_profiled(self) -> bool: + """Check if this scheme has been profiled (measured). + + A scheme is considered profiled if it has been measured (has non-infinite latency) + or has encountered an error during measurement. + + Returns: + True if scheme has been measured (latency_ms != inf) or has error, + False if scheme is waiting to be profiled (error=False and latency_ms=inf) + """ + return self.error or self.latency_ms != float("inf") + + @property + def num_node_insertions(self) -> int: + """Get count of node-level Q/DQ insertion points. + + Returns: + Number of NodeInputInsertionPoint entries + """ + return len(self.node_inputs) + + @property + def num_region_insertions(self) -> int: + """Get count of region-level Q/DQ insertion points. + + These specify Q/DQ insertions at child region boundaries within + COMPOSITE regions. + + Returns: + Number of ChildRegionInputInsertionPoint entries + """ + return len(self.child_region_inputs) + + @property + def num_region_output_insertions(self) -> int: + """Get count of region output insertion points. + + These specify Q/DQ insertions at outputs from child regions or nodes. + + Returns: + Number of RegionOutputInsertionPoint entries + """ + return len(self.region_outputs) + + def to_dict(self) -> dict[str, Any]: + """Convert to dictionary for serialization.""" + return { + "latency_ms": self.latency_ms, + "error": self.error, + "profile_timestamp": self.profile_timestamp, + "nodes_insertion_points": [pt.to_dict() for pt in self.node_inputs], + "child_region_inputs": [pt.to_dict() for pt in self.child_region_inputs], + "region_outputs": [pt.to_dict() for pt in self.region_outputs], + "hash": self.hash, + } + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> "InsertionScheme": + """Create InsertionScheme from serialized dictionary. + + Reconstructs the insertion scheme from saved data, including node and + region insertion points. The hash is automatically recomputed from all + components to ensure consistency. + + Args: + data: Dictionary containing 'latency_ms', 'nodes_insertion_points', + 'child_region_inputs', and 'region_outputs' keys + + Returns: + Reconstructed InsertionScheme instance + """ + scheme = cls() + scheme.latency_ms = data.get("latency_ms", float("inf")) + scheme.error = data.get("error", False) + scheme.profile_timestamp = data.get("profile_timestamp") + + scheme.node_inputs = [ + NodeInputInsertionPoint.from_dict(pt) for pt in data.get("nodes_insertion_points", []) + ] + scheme.child_region_inputs = [ + ChildRegionInputInsertionPoint.from_dict(pt) + for pt in data.get("child_region_inputs", []) + ] + scheme.region_outputs = [ + RegionOutputInsertionPoint.from_dict(pt) for pt in data.get("region_outputs", []) + ] + + # Note: hash is computed from points, so we don't load it from dict + # This ensures consistency even if stored hash differs + + return scheme + + def distance(self, other: "InsertionScheme") -> int: + """Compute edit distance between this scheme and another scheme. + + The edit distance is the minimum number of add/remove operations needed + to transform this scheme into the other scheme. This is computed as the + symmetric difference between the insertion point sets. + + **Distance Calculation:** + - Counts insertion points in self but not in other (need to be removed) + - Counts insertion points in other but not in self (need to be added) + - Considers all three types of insertion points: + * node_inputs + * child_region_inputs + * region_outputs + + Args: + other: InsertionScheme to compare against + + Returns: + Total edit distance (number of add + remove operations) + + Example: + >>> scheme1 = InsertionScheme( + ... node_inputs=[ + ... NodeInputInsertionPoint(0, 0), + ... NodeInputInsertionPoint(1, 0), + ... ] + ... ) + >>> scheme2 = InsertionScheme( + ... node_inputs=[ + ... NodeInputInsertionPoint(0, 0), + ... NodeInputInsertionPoint(2, 0), + ... ] + ... ) + >>> scheme1.distance(scheme2) # 2 (remove (1,0), add (2,0)) + 2 + """ + # Convert insertion points to sets for efficient set operations + self_nodes = set(self.node_inputs) + other_nodes = set(other.node_inputs) + + self_regions = set(self.child_region_inputs) + other_regions = set(other.child_region_inputs) + + self_region_outputs = set(self.region_outputs) + other_region_outputs = set(other.region_outputs) + + # Compute symmetric difference (elements in either set but not both) + # This gives us the total number of add + remove operations + node_distance = len(self_nodes.symmetric_difference(other_nodes)) + region_distance = len(self_regions.symmetric_difference(other_regions)) + region_output_distance = len(self_region_outputs.symmetric_difference(other_region_outputs)) + + return node_distance + region_distance + region_output_distance + + def __str__(self) -> str: + """String representation for debugging.""" + error_str = ", error=True" if self.error else "" + return ( + f"InsertionScheme(node_insertions={self.num_node_insertions}, " + f"region_insertions={self.num_region_insertions}, " + f"region_output_insertions={self.num_region_output_insertions}, " + f"latency={self.latency_ms:.3f}ms{error_str})" + ) diff --git a/modelopt/onnx/quantization/autotune/insertion_points.py b/modelopt/onnx/quantization/autotune/insertion_points.py new file mode 100644 index 000000000..d0dc7b945 --- /dev/null +++ b/modelopt/onnx/quantization/autotune/insertion_points.py @@ -0,0 +1,897 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Q/DQ Insertion Point Management for ONNX Quantization. + +This module provides data structures and utilities for managing Quantization/Dequantization (Q/DQ) +insertion points in ONNX computational graphs during autotune optimization. It enables pattern-based +Q/DQ insertion that can be reused across multiple matching regions in a model. + +Core Concepts: +-------------- +1. **Pattern-Relative Insertion Points**: Insertion points are defined relative to region patterns + rather than absolute node IDs, enabling scheme reuse across all matching regions. + +2. **Resolution Process**: Pattern-relative indices are resolved to actual tensor names for each + specific region instance, then Q/DQ pairs are inserted at the resolved locations. + +3. **Hierarchical Support**: Supports Q/DQ insertion at multiple levels: + - Node inputs within regions + - Child region boundaries (inputs/outputs) + - Region outputs + +Classes: +-------- +- ResolvedInsertionPoint: Resolved Q/DQ insertion point with actual tensor name +- NodeInputInsertionPoint: Pattern-relative insertion point at node inputs +- ChildRegionInputInsertionPoint: Pattern-relative insertion point at child region inputs +- RegionOutputInsertionPoint: Pattern-relative insertion point at region/node outputs + +Utilities: +---------- +- skip_invalid_insertion_points(): Filter out non-quantizable tensors +- has_quantizable_operations(): Check if region contains major quantizable ops +- resolve_region_io_insertion_points(): Resolve region I/O to actual insertion points +- merge_resolved_insertion_points(): Merge insertion points when all users are quantized + +Constants: +---------- +- BOOL_OPERATIONS: Boolean/comparison operations (not quantizable) +- SHAPE_OPERATIONS: Shape manipulation operations (not quantizable) +- MAJOR_QUANTIZABLE_OPERATIONS: Key operations that benefit from quantization +""" + +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any + +import numpy as np +import onnx_graphsurgeon as gs + +if TYPE_CHECKING: + from modelopt.onnx.quantization.autotune.common import Region + +from modelopt.onnx.quantization.graph_utils import get_tensor_consumer_node_indices + +BOOL_OPERATIONS = { + "Not", + "And", + "Or", + "Xor", + "BitwiseAnd", + "BitwiseOr", + "BitwiseXor", + "BitShift", + "IsNaN", + "IsInf", + "Sign", + "Abs", + "Equal", + "Greater", + "GreaterOrEqual", + "Less", + "LessOrEqual", + "Where", + "Max", + "Min", + "Mean", + "Median", + "ArgMax", + "ArgMin", + "ReduceMax", + "ReduceMin", + "ReduceSum", + "ReduceMean", + "All", + "Any", + "Unique", + "NonZero", + "TopK", +} + +SHAPE_OPERATIONS = { + "Cast", + "Ceil", + "Clip", + "Compress", + "Concat", + "ExpandDims", + "Flatten", + "Gather", + "GatherElements", + "GatherND", + "Identity", + "Pad", + "Range", + "Scatter", + "ScatterND", + "Shape", + "Slice", + "Split", + "Squeeze", + "Tile", + "Transpose", + "Unsqueeze", + "View", +} + +MAJOR_QUANTIZABLE_OPERATIONS = { + "Conv", + "ConvTranspose", + "Gemm", + "MatMul", + "AveragePool", + "MaxPool", + "GlobalAveragePool", + "GlobalMaxPool", + "Resize", + "Add", + "Sum", + "Mul", + "Relu", +} + + +@dataclass(frozen=True) +class ResolvedInsertionPoint: + """Resolved Q/DQ insertion point with actual tensor name and optional node context. + + After resolving pattern-relative insertion points, this class represents the + actual location where Q/DQ pairs should be inserted in the graph. + + **Insertion Modes:** + 1. Node-specific insertion (node_index and input_index are set): + - Inserts Q/DQ at a specific input of a specific node + - More precise control over where quantization happens + 2. Tensor-level insertion (node_index and input_index are None): + - Inserts Q/DQ for all users of the tensor + - Used when all consumers of a tensor should be quantized together + + **Attributes:** + - tensor_name: Name of the tensor where Q/DQ should be inserted + - node_index: Absolute graph node index (not pattern-relative), or None for tensor-level insertion + - input_index: Input tensor index of that node, or None for tensor-level insertion + + This class is immutable (frozen) to allow safe use in sets and as dict keys. + """ + + tensor_name: str + # Absolute graph node index (or None for tensor-level insertion) + node_index: int | None = None + # Input tensor index of that node (or None) + input_index: int | None = None + + def to_dict(self) -> dict[str, Any]: + """Convert to dictionary for serialization.""" + return { + "tensor_name": self.tensor_name, + "node_index": self.node_index, + "input_index": self.input_index, + } + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> "ResolvedInsertionPoint": + """Create from dictionary.""" + return cls( + tensor_name=data["tensor_name"], + node_index=data["node_index"], + input_index=data.get("input_index"), + ) + + def __str__(self) -> str: + """String representation for debugging.""" + return ( + f"ResolvedInsertionPoint(tensor_name={self.tensor_name}, " + f"node={self.node_index}, input={self.input_index})" + ) + + +@dataclass(frozen=True) +class NodeInputInsertionPoint: + """Pattern-relative Q/DQ insertion point at a node's input. + + Specifies where to insert a Q/DQ pair within a region pattern using + pattern-relative indices rather than absolute node IDs. This enables + insertion scheme reuse across all regions matching the same pattern. + + **Resolution Process:** + 1. Pattern-relative indices (node_index, input_index) are defined once + 2. For each matching region, indices are resolved to actual tensor names + 3. Q/DQ pairs are inserted at the resolved tensor locations + + **Example:** + - NodeInputInsertionPoint(node_index=0, input_index=1) + - Resolves to: the second input (index 1) of the first node (index 0) in the pattern + - Actual tensor name depends on the specific region instance + + **Attributes:** + - node_index: Index of the node within the pattern's sorted node list (0-based) + - input_index: Index of the input tensor for that node (0-based) + + This class is immutable (frozen) to allow safe use in sets and as dict keys. + """ + + # Pattern-relative node index + node_index: int + # Input tensor index of that node + input_index: int + + def to_dict(self) -> dict[str, Any]: + """Convert to dictionary for serialization.""" + return {"node_index": self.node_index, "input_index": self.input_index} + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> "NodeInputInsertionPoint": + """Create from dictionary.""" + return cls(node_index=data["node_index"], input_index=data["input_index"]) + + def __str__(self) -> str: + """String representation for debugging.""" + return f"NodeInputInsertionPoint(node={self.node_index}, input={self.input_index})" + + def resolve(self, region: "Region", graph: gs.Graph) -> set[ResolvedInsertionPoint]: + """Resolve a node input insertion point to actual tensor names for a matching region. + + Converts pattern-relative node/input indices to absolute node indices and actual + tensor names in the graph. Special handling for Conv/ConvTranspose operations + automatically includes weight quantization when input is quantized. + + Args: + region: The region instance matching this pattern + graph: The ONNX graph containing the nodes + + Returns: + Set of ResolvedInsertionPoint objects with actual tensor names + """ + nodes_list = list(graph.nodes) + node_indices = sorted(region.get_nodes()) + resolved_ips = set() + + # Map from pattern-relative node index to absolute graph node index + assert self.node_index < len(node_indices), "Node index out of range" + actual_node_idx = node_indices[self.node_index] + assert actual_node_idx < len(nodes_list), "Node index out of range" + node = nodes_list[actual_node_idx] + assert self.input_index < len(node.inputs), "Input index out of range" + + # Resolve the input tensor name using input_index + inp = node.inputs[self.input_index] + if hasattr(inp, "name") and inp.name: + ip = ResolvedInsertionPoint( + tensor_name=inp.name, node_index=actual_node_idx, input_index=self.input_index + ) + resolved_ips.add(ip) + + if node.op in ["Conv", "ConvTranspose"]: + assert self.input_index == 0, ( + "Conv and ConvTranspose inputs and weights should be quantized at same time" + ) + assert len(node.inputs) >= 2, "Conv and ConvTranspose should have at least 2 inputs" + inp = node.inputs[1] + if hasattr(inp, "name") and inp.name: + ip = ResolvedInsertionPoint( + tensor_name=inp.name, node_index=actual_node_idx, input_index=1 + ) + resolved_ips.add(ip) + + return resolved_ips + + @staticmethod + def collect_from_region(region: "Region", graph: gs.Graph) -> list["NodeInputInsertionPoint"]: + """Collect all valid node input insertion points from a region. + + Analyzes each node in the region and identifies all valid input tensors + where Q/DQ pairs could be inserted. Filters out invalid insertion points + using skip_invalid_insertion_points(). + + Args: + region: The region to collect insertion points from + graph: The ONNX graph containing the nodes + + Returns: + List of NodeInputInsertionPoint objects representing valid insertion locations + """ + nodes_list = list(graph.nodes) + node_indices = sorted(region.get_nodes()) + + node_input_insertion_points = [] + for local_idx, node_idx in enumerate(node_indices): + assert node_idx < len(nodes_list), "Node index out of range" + node = nodes_list[node_idx] + # Analyze each input of the node + for input_idx, inp in enumerate(node.inputs): + # Skip if tensor doesn't have a valid name + if not (hasattr(inp, "name") and inp.name): + continue + # Skip if insertion point is invalid (wrong dtype, small size, special input, etc.) + if skip_invalid_insertion_points(graph, inp.name, node): + continue + # Create insertion point for valid tensor + ip = NodeInputInsertionPoint( + # Pattern-relative node index + node_index=local_idx, + input_index=input_idx, + ) + node_input_insertion_points.append(ip) + + return node_input_insertion_points + + +@dataclass(frozen=True) +class ChildRegionInputInsertionPoint: + """Pattern-relative Q/DQ insertion point at a child region's input boundary. + + Specifies where to insert Q/DQ pairs at the input boundaries of child regions + within COMPOSITE regions. This allows parent regions to control quantization + at child boundaries, potentially overriding or complementing child region + optimizations. + + **Use Case:** + Parent regions can insert Q/DQ pairs at child region inputs to: + - Add quantization at child boundaries even if the child has no internal Q/DQ + - Override or supplement the child's own boundary Q/DQ decisions + - Apply different quantization schemes based on the parent context + + **Resolution Process:** + 1. Pattern-relative indices (region_index, input_index) are defined once + 2. For each matching parent region, indices resolve to actual child boundaries: + - region_index identifies which child region (in parent's sorted child list) + - input_index identifies which input tensor of that child region + 3. Q/DQ pairs are inserted at the resolved child input tensor locations + + **Example:** + - ChildRegionInputInsertionPoint(region_index=0, input_index=1) + - Resolves to: the second input tensor (index 1) of the first child region (index 0) + - Actual tensor name depends on the specific parent/child region instances + + **Note:** Only applies to COMPOSITE regions. LEAF regions have no children, + so child region insertion points have no effect there. + + **Attributes:** + - region_index: Index of the child region within the parent pattern's sorted child list (0-based) + - input_index: Index of the input tensor for that child region (0-based) + + This class is immutable (frozen) to allow safe use in sets and as dict keys. + """ + + # Index of the child region within the parent pattern's sorted child list (0-based) + region_index: int + # Index of the input tensor for that child region (0-based) + input_index: int + + def to_dict(self) -> dict[str, Any]: + """Convert to dictionary for serialization.""" + return {"region_index": self.region_index, "input_index": self.input_index} + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> "ChildRegionInputInsertionPoint": + """Create from dictionary. + + Backward compatible: Ignores obsolete fields like 'child_region_id' + from older serialization formats. + + Args: + data: Dictionary with 'region_index' and 'input_index' keys + + Returns: + ChildRegionInputInsertionPoint instance + """ + # Ignore child_region_id if present in old data + return cls(region_index=data["region_index"], input_index=data["input_index"]) + + def __str__(self) -> str: + """String representation for debugging.""" + return ( + f"ChildRegionInputInsertionPoint(region={self.region_index}, input={self.input_index})" + ) + + def resolve(self, region: "Region", graph: gs.Graph) -> set[ResolvedInsertionPoint]: + """Resolve a child region input insertion point to actual tensor names for a matching region. + + Converts pattern-relative child region index and input index to the actual tensor + name at that child region's input boundary, then resolves to all node inputs that + consume that tensor. + + Args: + region: The parent region instance matching this pattern + graph: The ONNX graph containing the nodes + + Returns: + Set of ResolvedInsertionPoint objects with actual tensor names. + Returns empty set for LEAF regions (no children). + """ + from modelopt.onnx.quantization.autotune.common import RegionType + + if graph is None: + raise ValueError("graph parameter is required") + + # LEAF regions have no child boundaries + if region.get_type() == RegionType.LEAF: + return set() + + # Get sorted child regions (must match order in RegionPattern._compute_signature_recursive) + children_regions = region.get_children() + children_regions = sorted( + children_regions, key=lambda r: (-r.get_level(), r.get_total_size()) + ) + # Map from pattern-relative child index to actual child region + resolved_ips = set() + assert self.region_index < len(children_regions), "Child region index out of range" + child_region = children_regions[self.region_index] + assert self.input_index < len(child_region.get_inputs()), "Input index out of range" + # Resolve the input tensor name using input_index + tensor_name = child_region.get_inputs()[self.input_index] + assert tensor_name is not None, "Tensor name is required" + resolved_ips.update(resolve_region_io_insertion_points(child_region, graph, tensor_name)) + + return resolved_ips + + @staticmethod + def collect_from_region( + region: "Region", graph: gs.Graph + ) -> list["ChildRegionInputInsertionPoint"]: + """Collect all valid child region input insertion points from a region. + + For COMPOSITE regions, analyzes each child region and identifies all valid + input tensors where Q/DQ pairs could be inserted at child boundaries. + Returns empty list for LEAF regions (no children). + + Args: + region: The parent region to collect insertion points from + graph: The ONNX graph containing the nodes + + Returns: + List of ChildRegionInputInsertionPoint objects representing valid insertion locations + """ + from modelopt.onnx.quantization.autotune.common import RegionType + + child_region_input_insertion_points = [] + + # Only COMPOSITE regions have child boundaries for Q/DQ insertion + if region.get_type() != RegionType.LEAF: + # Get all child regions, sorted for deterministic ordering + # Must match sorting in _compute_signature_recursive to ensure + # insertion point indices align with pattern structure + children_regions = region.get_children() + children_regions = sorted( + children_regions, key=lambda r: (-r.get_level(), r.get_total_size()) + ) + + for local_idx, child_region in enumerate(children_regions): + # Create insertion point for each input tensor of the child region + for input_idx, inp in enumerate(child_region.get_inputs()): + if skip_invalid_insertion_points(graph, inp, child_region): + continue + point = ChildRegionInputInsertionPoint( + # Child region index within parent pattern + region_index=local_idx, + # Input index within child region + input_index=input_idx, + ) + child_region_input_insertion_points.append(point) + + return child_region_input_insertion_points + + +@dataclass(frozen=True) +class RegionOutputInsertionPoint: + """Pattern-relative Q/DQ insertion point at an output location. + + Specifies where to insert Q/DQ pairs at output boundaries. This can be either: + 1. Output from a child region (in COMPOSITE regions) + 2. Output from a node within the region + + **Use Case:** + Parent regions can: + - Add Q/DQ at child region output boundaries + - Add Q/DQ at node outputs within the region + - Control quantization precision as data flows through the region hierarchy + + **Resolution Process:** + 1. Pattern-relative indices are defined once + 2. If output is from a child region: use region_index (node_index is None) + - region_index identifies which child region (in sorted order) + - output_index identifies which output tensor of that child region + 3. If output is from a node: use node_index (region_index is None) + - node_index identifies which node (in sorted order) + - output_index identifies which output tensor of that node + 4. Resolves to the actual tensor name at that output location + + **Examples:** + - RegionOutputInsertionPoint(region_index=0, node_index=None, output_index=0) + → First output of the first child region + - RegionOutputInsertionPoint(region_index=None, node_index=2, output_index=1) + → Second output of the third node + + **Note:** Exactly one of region_index or node_index must be set (the other must be None). + + **Attributes:** + - region_index: Index of child region within parent pattern (0-based), or None + - node_index: Index of node within the region (0-based), or None + - output_index: Index of the output tensor (0-based) + + This class is immutable (frozen) to allow safe use in sets and as dict keys. + """ + + # Index of child region within parent pattern (0-based), or None + region_index: int | None + # Index of node within the region (0-based), or None + node_index: int | None + # Index of the output tensor (0-based) + output_index: int + + def to_dict(self) -> dict[str, Any]: + """Convert to dictionary for serialization.""" + return { + "region_index": self.region_index, + "node_index": self.node_index, + "output_index": self.output_index, + } + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> "RegionOutputInsertionPoint": + """Create from dictionary. + + Args: + data: Dictionary with 'region_index', 'node_index', and 'output_index' keys + + Returns: + RegionOutputInsertionPoint instance + """ + return cls( + region_index=data.get("region_index"), + node_index=data.get("node_index"), + output_index=data["output_index"], + ) + + def __str__(self) -> str: + """String representation for debugging.""" + if self.region_index is not None: + return f"RegionOutputInsertionPoint(region={self.region_index}, output={self.output_index})" + else: + return f"RegionOutputInsertionPoint(node={self.node_index}, output={self.output_index})" + + def resolve(self, region: "Region", graph: gs.Graph) -> set[ResolvedInsertionPoint]: + """Resolve a region output insertion point to actual tensor names for a matching region. + + Converts pattern-relative indices to the actual tensor name at an output location: + - If region_index is set: Resolves to a child region's output tensor + - If node_index is set: Resolves to a node's output tensor + + Then identifies all node inputs that consume that output tensor. + + Args: + region: The region instance matching this pattern + graph: The ONNX graph containing the nodes + + Returns: + Set of ResolvedInsertionPoint objects with actual tensor names + """ + if graph is None: + raise ValueError("graph parameter is required") + + # Get sorted nodes for node output resolution + nodes_list = list(graph.nodes) + node_indices = sorted(region.get_nodes()) + children_regions = region.get_children() + children_regions = sorted( + children_regions, key=lambda r: (-r.get_level(), r.get_total_size()) + ) + + # Resolve each region output insertion point from the scheme to actual tensor names + resolved_ips = set() + # Handle child region outputs (region_index is set) + if self.region_index is not None: + assert self.region_index < len(children_regions), "Region index out of range" + child_region = children_regions[self.region_index] + assert self.output_index < len(child_region.get_outputs()), "Output index out of range" + tensor_name = child_region.get_outputs()[self.output_index] + assert tensor_name is not None, "Invalid tensor name" + resolved_ips.update( + resolve_region_io_insertion_points(child_region, graph, tensor_name) + ) + # Handle node outputs (node_index is set) + elif self.node_index is not None: + assert self.node_index < len(node_indices), "Node index out of range" + node_idx = node_indices[self.node_index] + assert node_idx < len(nodes_list), "Node index out of range" + node = nodes_list[node_idx] + assert self.output_index < len(node.outputs), "Output index out of range" + tensor = node.outputs[self.output_index] + assert tensor is not None, "Invalid tensor name" + assert hasattr(tensor, "name") and tensor.name, "Tensor name is required" + resolved_ips.update(resolve_region_io_insertion_points(None, graph, tensor.name)) + return resolved_ips + + @staticmethod + def collect_from_region( + region: "Region", graph: gs.Graph + ) -> list["RegionOutputInsertionPoint"]: + """Collect all valid region output insertion points from a region. + + Identifies all valid output tensors (from child regions or nodes) that leave + the region boundary and could have Q/DQ pairs inserted. Only includes outputs + that are actual region outputs (not consumed internally). + + For COMPOSITE regions: + - Collects child region outputs that are also region outputs + - Collects node outputs that are region outputs + + For LEAF regions: + - Only collects node outputs that are region outputs + + Args: + region: The region to collect insertion points from + graph: The ONNX graph containing the nodes + + Returns: + List of RegionOutputInsertionPoint objects representing valid insertion locations + """ + from modelopt.onnx.quantization.autotune.common import RegionType + + nodes_list = list(graph.nodes) + node_indices = sorted(region.get_nodes()) + region_outputs_set = set(region.get_outputs()) + + # Only include outputs that are actual region outputs (leave the region) + region_output_insertion_points = [] + if region.get_type() != RegionType.LEAF: + # For COMPOSITE regions: check if child region output is a region output + children_regions = region.get_children() + children_regions = sorted( + children_regions, key=lambda r: (-r.get_level(), r.get_total_size()) + ) + for local_idx, child_region in enumerate(children_regions): + for output_idx, out in enumerate(child_region.get_outputs()): + if out not in region_outputs_set: + continue + if skip_invalid_insertion_points(graph, out, child_region): + continue + point = RegionOutputInsertionPoint( + region_index=local_idx, + node_index=None, + output_index=output_idx, + ) + region_output_insertion_points.append(point) + # For all regions: check if node output is a region output + for local_idx, node_idx in enumerate(node_indices): + assert node_idx < len(nodes_list), "Node index out of range" + node = nodes_list[node_idx] + for output_idx, out in enumerate(node.outputs): + # Skip if tensor doesn't have a valid name + if not (hasattr(out, "name") and out.name): + continue + # Skip if this output is not a region output (i.e., it's consumed internally) + if out.name not in region_outputs_set: + continue + # Skip if insertion point is invalid (wrong dtype, small size, etc.) + if skip_invalid_insertion_points(graph, out.name, node): + continue + # Create insertion point for valid output tensor + point = RegionOutputInsertionPoint( + region_index=None, + node_index=local_idx, + output_index=output_idx, + ) + region_output_insertion_points.append(point) + + return region_output_insertion_points + + +InsertionPointType = ( + NodeInputInsertionPoint | ChildRegionInputInsertionPoint | RegionOutputInsertionPoint +) + + +def skip_invalid_insertion_points( + graph: gs.Graph, tensor_name: str, region_or_node: "Region | gs.Node" +) -> bool: + """Determine if a tensor should be skipped for Q/DQ insertion. + + Filters out tensors that are not suitable for quantization based on various criteria: + - Boolean and shape operations (not quantizable) + - Fused operation patterns (Conv->BatchNorm->ReLU) + - Operation-specific non-quantizable inputs (weights, biases, BN parameters) + - Non-floating-point tensors (indices, masks) + - Small tensors (scalars, small vectors with < 8 elements) + + Args: + graph: The ONNX graph containing the nodes + tensor_name: Name of the tensor to evaluate + region_or_node: Either a Region or a Node to check for usage of this tensor + + Returns: + True if the insertion point should be skipped, False if it's valid for quantization + """ + from modelopt.onnx.quantization.autotune.common import Region + + if isinstance(region_or_node, Region): + node_indices = region_or_node.get_all_nodes_recursive() + nodes: list[gs.Node] = [graph.nodes[node_idx] for node_idx in node_indices] + else: + assert isinstance(region_or_node, gs.Node) + nodes = [region_or_node] + + for node in nodes: + for input_idx, inp in enumerate(node.inputs): + if hasattr(inp, "name") and inp.name == tensor_name: + # Skip weights of Conv and ConvTranspose, they should be quantized with inputs at same time + if node.op in ["Conv", "ConvTranspose"] and input_idx >= 1: + return True + if node.op in ["Relu", "LeakyRelu", "Softmax"]: + # Conv -> ReLU/LeakyRelu/Softmax + if len(node.inputs) == 1 and len(node.inputs[0].inputs) == 1: + producer = node.inputs[0].inputs[0] + if producer.op in ["Conv", "ConvTranspose"]: + return True + # Conv -> BatchNormalization -> ReLU/LeakyRelu/Softmax + if len(node.inputs) == 1 and len(node.inputs[0].inputs) == 1: + producer = node.inputs[0].inputs[0] + if producer.op == "BatchNormalization": + assert len(producer.inputs) >= 1, ( + "BN node should have more than one inputs" + ) + if len(producer.inputs[0].inputs) == 1: + producer = producer.inputs[0].inputs[0] + if producer.op in ["Conv", "ConvTranspose"]: + return True + # Conv -> BatchNormalization -> ReLU/LeakyRelu/Softmax + if node.op == "BatchNormalization": + assert len(node.inputs) >= 1, "BN node should have more than one inputs" + if len(node.inputs[0].inputs) == 1: + producer = node.inputs[0].inputs[0] + if producer.op in ["Conv", "ConvTranspose"]: + return True + # Filter 1: out boolean operations + if node.op in BOOL_OPERATIONS: + return True + # Filter 2: out shape operations + if node.op in SHAPE_OPERATIONS: + return True + # Filter 3: Skip operation-specific non-quantizable inputs + if node.op in ["BatchNormalization", "Resize"] and input_idx >= 1: + return True + if node.op in ["Conv", "Gemm"] and input_idx >= 2: + return True + # Filter 4: Skip non-floating-point tensors (int/bool indices, masks, etc.) + if hasattr(inp, "dtype") and inp.dtype not in [ + None, + np.float32, + np.float16, + np.float64, + ]: + return True + # Filter 5: Skip small tensors (scalars, small vectors) + if hasattr(inp, "shape") and inp.shape is not None: + if all(isinstance(s, int) for s in inp.shape): + if np.prod(inp.shape) < 8: + return True + return False + + +def has_quantizable_operations(region: "Region", graph: gs.Graph) -> bool: + """Check if a region contains major quantizable operations. + + Args: + region: The region to check + graph: The ONNX graph containing the nodes + + Returns: + True if the region contains major quantizable operations, False otherwise + """ + from modelopt.onnx.quantization.autotune.common import RegionType + + # only check leaf regions for quantizable operations + if region.get_type() == RegionType.LEAF: + region_ops = {graph.nodes[idx].op for idx in region.get_nodes()} + return bool(region_ops.intersection(MAJOR_QUANTIZABLE_OPERATIONS)) + return True + + +def resolve_region_io_insertion_points( + region: "Region | None", graph: gs.Graph, tensor_name: str +) -> set[ResolvedInsertionPoint]: + """Resolve region input/output boundaries to actual Q/DQ insertion points. + + For a given tensor at a region boundary (input or output), this function + identifies all the actual node inputs where Q/DQ pairs should be inserted. + It considers both nodes within the region (if provided) and all users of + the tensor in the graph. + + **Use Cases:** + - Child region inputs: Find all nodes inside the child that consume the input tensor + - Child region outputs: Find all nodes outside the child that consume the output tensor + - Node outputs: Find all nodes that consume the tensor (region can be None) + + Args: + region: The region to search within (or None to search entire graph) + graph: The ONNX graph containing the nodes + tensor_name: Name of the tensor at the region boundary + + Returns: + Set of ResolvedInsertionPoint objects specifying where to insert Q/DQ pairs + """ + resolved_insertion_points = set() + tensor_users_map: dict[str, list[int]] = {} + if hasattr(graph, "tensor_users_map"): + tensor_users_map = graph.tensor_users_map + if not tensor_users_map: + tensor_users_map = get_tensor_consumer_node_indices(graph) + + if region is not None: + for node_idx in region.get_all_nodes_recursive(): + assert node_idx < len(graph.nodes), "Node index out of range" + node = graph.nodes[node_idx] + for input_idx, inp in enumerate(node.inputs): + if inp.name == tensor_name: + ip = ResolvedInsertionPoint( + tensor_name=tensor_name, node_index=node_idx, input_index=input_idx + ) + resolved_insertion_points.add(ip) + + if tensor_name in tensor_users_map: + for node_idx in tensor_users_map[tensor_name]: + node = graph.nodes[node_idx] + for input_idx, inp in enumerate(node.inputs): + if inp.name == tensor_name: + ip = ResolvedInsertionPoint( + tensor_name=tensor_name, node_index=node_idx, input_index=input_idx + ) + resolved_insertion_points.add(ip) + + return resolved_insertion_points + + +def merge_resolved_insertion_points( + graph: gs.Graph, resolved_insertion_points: set[ResolvedInsertionPoint] +) -> set[ResolvedInsertionPoint]: + """Optimize insertion points by merging node-specific insertions into tensor-level insertions. + + When all consumers (users) of a tensor have Q/DQ insertion points, it's more efficient + to insert Q/DQ once at the tensor level rather than at each individual node input. + This reduces the number of Q/DQ nodes in the graph and simplifies the quantization scheme. + + **Optimization Logic:** + - For each tensor with multiple node-specific insertion points: + - If ALL users of the tensor have insertion points → merge to tensor-level insertion + - If SOME users have insertion points → keep node-specific insertions + + Args: + graph: The ONNX graph containing the nodes + resolved_insertion_points: Set of resolved insertion points to optimize + + Returns: + Optimized set of insertion points with merged tensor-level insertions where possible + """ + tensor_users_map = get_tensor_consumer_node_indices(graph) + node_input_insertion_points = { + ip for ip in resolved_insertion_points if ip.node_index is not None + } + tensor_names = {ip.tensor_name for ip in node_input_insertion_points} + + results = resolved_insertion_points.difference(node_input_insertion_points) + for tensor_name in tensor_names: + all_users = set(tensor_users_map[tensor_name]) + qdq_users = { + user for user in node_input_insertion_points if user.tensor_name == tensor_name + } + qdq_user_ids = set({user.node_index for user in qdq_users}) + if all_users == qdq_user_ids: + results.add( + ResolvedInsertionPoint(tensor_name=tensor_name, node_index=None, input_index=None) + ) + else: + results.update(qdq_users) + + return results diff --git a/modelopt/onnx/quantization/graph_utils.py b/modelopt/onnx/quantization/graph_utils.py index 67596d5df..63633279e 100755 --- a/modelopt/onnx/quantization/graph_utils.py +++ b/modelopt/onnx/quantization/graph_utils.py @@ -302,6 +302,24 @@ def get_tensor_consumer_nodes( return tensor_consumers +def get_tensor_consumer_node_indices(graph: onnx.GraphProto | gs.Graph) -> dict[str, list[int]]: + """Build a mapping from tensor names to the indices of nodes that use them. + + Args: + graph: ONNX GraphSurgeon graph to analyze + + Returns: + Dictionary mapping tensor names to lists of node indices that consume them + """ + tensor_consumer_map: dict[str, list[int]] = defaultdict(list) + nodes = graph.nodes if isinstance(graph, gs.Graph) else graph.node + for node_idx, node in enumerate(nodes): + inputs = node.inputs if isinstance(node, gs.Node) else node.input + for tensor in inputs: + tensor_consumer_map[tensor.name].append(node_idx) + return tensor_consumer_map + + def filter_quantizable_kgen_heads( cask_fusible_partitions: list[list[Node]], kgen_partitions: list[list[Node]], diff --git a/tests/unit/onnx/quantization/autotune/test_insertion_points.py b/tests/unit/onnx/quantization/autotune/test_insertion_points.py new file mode 100644 index 000000000..726e5606b --- /dev/null +++ b/tests/unit/onnx/quantization/autotune/test_insertion_points.py @@ -0,0 +1,376 @@ +#!/usr/bin/env python3 +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Comprehensive tests for common data structures in the autotuner. + +Tests: +1. InsertionPoint classes (NodeInputInsertionPoint, RegionOutputInsertionPoint, ChildRegionInputInsertionPoint) +2. InsertionScheme serialization/deserialization +3. InsertionScheme hashing and equality +4. InsertionScheme properties and methods +5. PatternSchemes management +""" + +import os +import sys +import unittest + +import pytest + +# Add parent directory to path +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +from modelopt.onnx.quantization.autotune.common import ( + ChildRegionInputInsertionPoint, + InsertionScheme, + NodeInputInsertionPoint, + RegionOutputInsertionPoint, +) + + +class TestNodeInputInsertionPoint(unittest.TestCase): + """Test NodeInputInsertionPoint functionality.""" + + def test_creation(self): + """Test creating NodeInputInsertionPoint.""" + point = NodeInputInsertionPoint(node_index=5, input_index=2) + assert point.node_index == 5 + assert point.input_index == 2 + print("✓ NodeInputInsertionPoint creation") + + def test_immutability(self): + """Test that NodeInputInsertionPoint is immutable (frozen).""" + point = NodeInputInsertionPoint(node_index=1, input_index=0) + with pytest.raises(AttributeError): + point.node_index = 2 + print("✓ NodeInputInsertionPoint is immutable") + + def test_equality(self): + """Test equality comparison.""" + point1 = NodeInputInsertionPoint(node_index=3, input_index=1) + point2 = NodeInputInsertionPoint(node_index=3, input_index=1) + point3 = NodeInputInsertionPoint(node_index=3, input_index=2) + + assert point1 == point2 + assert point1 != point3 + print("✓ NodeInputInsertionPoint equality") + + def test_hashable(self): + """Test that points can be used in sets and dicts.""" + point1 = NodeInputInsertionPoint(node_index=1, input_index=0) + point2 = NodeInputInsertionPoint(node_index=1, input_index=0) + point3 = NodeInputInsertionPoint(node_index=2, input_index=0) + + point_set = {point1, point2, point3} + assert len(point_set) == 2 # point1 and point2 are the same + print("✓ NodeInputInsertionPoint is hashable") + + def test_serialization(self): + """Test to_dict and from_dict.""" + point = NodeInputInsertionPoint(node_index=7, input_index=3) + + data = point.to_dict() + assert data["node_index"] == 7 + assert data["input_index"] == 3 + + restored = NodeInputInsertionPoint.from_dict(data) + assert point == restored + print("✓ NodeInputInsertionPoint serialization") + + def test_string_representation(self): + """Test __str__ method.""" + point = NodeInputInsertionPoint(node_index=2, input_index=1) + s = str(point) + assert "2" in s + assert "1" in s + print("✓ NodeInputInsertionPoint string representation") + + +class TestRegionOutputInsertionPoint(unittest.TestCase): + """Test RegionOutputInsertionPoint functionality.""" + + def test_creation_with_region_index(self): + """Test creating with region_index (child region output).""" + point = RegionOutputInsertionPoint(region_index=2, node_index=None, output_index=1) + assert point.region_index == 2 + assert point.node_index is None + assert point.output_index == 1 + print("✓ RegionOutputInsertionPoint with region_index") + + def test_creation_with_node_index(self): + """Test creating with node_index (node output).""" + point = RegionOutputInsertionPoint(region_index=None, node_index=5, output_index=0) + assert point.region_index is None + assert point.node_index == 5 + assert point.output_index == 0 + print("✓ RegionOutputInsertionPoint with node_index") + + def test_equality(self): + """Test equality comparison.""" + point1 = RegionOutputInsertionPoint(region_index=1, node_index=None, output_index=0) + point2 = RegionOutputInsertionPoint(region_index=1, node_index=None, output_index=0) + point3 = RegionOutputInsertionPoint(region_index=None, node_index=1, output_index=0) + + assert point1 == point2 + assert point1 != point3 + print("✓ RegionOutputInsertionPoint equality") + + def test_serialization_region_index(self): + """Test serialization with region_index.""" + point = RegionOutputInsertionPoint(region_index=3, node_index=None, output_index=2) + + data = point.to_dict() + assert data["region_index"] == 3 + assert data["node_index"] is None + assert data["output_index"] == 2 + + restored = RegionOutputInsertionPoint.from_dict(data) + assert point == restored + print("✓ RegionOutputInsertionPoint serialization (region_index)") + + def test_serialization_node_index(self): + """Test serialization with node_index.""" + point = RegionOutputInsertionPoint(region_index=None, node_index=7, output_index=1) + + data = point.to_dict() + assert data["region_index"] is None + assert data["node_index"] == 7 + assert data["output_index"] == 1 + + restored = RegionOutputInsertionPoint.from_dict(data) + assert point == restored + print("✓ RegionOutputInsertionPoint serialization (node_index)") + + def test_string_representation(self): + """Test __str__ method.""" + point1 = RegionOutputInsertionPoint(region_index=2, node_index=None, output_index=1) + s1 = str(point1) + assert "region" in s1.lower() + assert "2" in s1 + + point2 = RegionOutputInsertionPoint(region_index=None, node_index=5, output_index=0) + s2 = str(point2) + assert "node" in s2.lower() + assert "5" in s2 + print("✓ RegionOutputInsertionPoint string representation") + + +class TestChildRegionInputInsertionPoint(unittest.TestCase): + """Test ChildRegionInputInsertionPoint functionality.""" + + def test_creation(self): + """Test creating ChildRegionInputInsertionPoint.""" + point = ChildRegionInputInsertionPoint(region_index=3, input_index=1) + assert point.region_index == 3 + assert point.input_index == 1 + print("✓ ChildRegionInputInsertionPoint creation") + + def test_equality(self): + """Test equality comparison.""" + point1 = ChildRegionInputInsertionPoint(region_index=2, input_index=0) + point2 = ChildRegionInputInsertionPoint(region_index=2, input_index=0) + point3 = ChildRegionInputInsertionPoint(region_index=2, input_index=1) + + assert point1 == point2 + assert point1 != point3 + print("✓ ChildRegionInputInsertionPoint equality") + + def test_serialization(self): + """Test to_dict and from_dict.""" + point = ChildRegionInputInsertionPoint(region_index=5, input_index=2) + + data = point.to_dict() + assert data["region_index"] == 5 + assert data["input_index"] == 2 + + restored = ChildRegionInputInsertionPoint.from_dict(data) + assert point == restored + print("✓ ChildRegionInputInsertionPoint serialization") + + +class TestInsertionScheme(unittest.TestCase): + """Test InsertionScheme functionality.""" + + def test_empty_scheme(self): + """Test empty InsertionScheme.""" + scheme = InsertionScheme() + + assert scheme.is_empty + assert scheme.num_node_insertions == 0 + assert scheme.num_region_insertions == 0 + assert scheme.num_region_output_insertions == 0 + assert not scheme.error + print("✓ Empty InsertionScheme") + + def test_scheme_with_node_inputs(self): + """Test scheme with node input insertion points.""" + scheme = InsertionScheme() + scheme.node_inputs = [NodeInputInsertionPoint(0, 0), NodeInputInsertionPoint(1, 0)] + + assert not scheme.is_empty + assert scheme.num_node_insertions == 2 + print("✓ InsertionScheme with node inputs") + + def test_scheme_with_region_outputs(self): + """Test scheme with region output insertion points.""" + scheme = InsertionScheme() + scheme.region_outputs = [ + RegionOutputInsertionPoint(None, 0, 0), + RegionOutputInsertionPoint(1, None, 0), + ] + + assert not scheme.is_empty + assert scheme.num_region_output_insertions == 2 + print("✓ InsertionScheme with region outputs") + + def test_scheme_with_composite_regions(self): + """Test scheme with composite region insertion points.""" + scheme = InsertionScheme() + scheme.child_region_inputs = [ + ChildRegionInputInsertionPoint(0, 0), + ChildRegionInputInsertionPoint(1, 0), + ] + + assert not scheme.is_empty + assert scheme.num_region_insertions == 2 + print("✓ InsertionScheme with composite regions") + + def test_scheme_hash_empty(self): + """Test hash of empty scheme.""" + scheme1 = InsertionScheme() + scheme2 = InsertionScheme() + + assert scheme1.hash == scheme2.hash + print("✓ Empty scheme hash consistency") + + def test_scheme_hash_with_points(self): + """Test hash with insertion points.""" + scheme1 = InsertionScheme() + scheme1.node_inputs = [NodeInputInsertionPoint(0, 0), NodeInputInsertionPoint(1, 0)] + + scheme2 = InsertionScheme() + scheme2.node_inputs = [NodeInputInsertionPoint(0, 0), NodeInputInsertionPoint(1, 0)] + + scheme3 = InsertionScheme() + scheme3.node_inputs = [ + NodeInputInsertionPoint(0, 0), + NodeInputInsertionPoint(2, 0), # Different + ] + + assert scheme1.hash == scheme2.hash + assert scheme1.hash != scheme3.hash + print("✓ Scheme hash with points") + + def test_scheme_hash_order_independent(self): + """Test that hash is independent of insertion point order.""" + scheme1 = InsertionScheme() + scheme1.node_inputs = [NodeInputInsertionPoint(0, 0), NodeInputInsertionPoint(1, 0)] + + scheme2 = InsertionScheme() + scheme2.node_inputs = [ + NodeInputInsertionPoint(1, 0), + NodeInputInsertionPoint(0, 0), # Reversed order + ] + + # Hash should be the same regardless of order + assert scheme1.hash == scheme2.hash + print("✓ Scheme hash is order-independent") + + def test_serialization_empty(self): + """Test serialization of empty scheme.""" + scheme = InsertionScheme() + + data = scheme.to_dict() + restored = InsertionScheme.from_dict(data) + + assert restored.is_empty + assert restored.latency_ms == float("inf") + assert not restored.error + print("✓ Empty scheme serialization") + + def test_serialization_full(self): + """Test serialization with all types of insertion points.""" + scheme = InsertionScheme() + scheme.node_inputs = [NodeInputInsertionPoint(0, 0)] + scheme.child_region_inputs = [ChildRegionInputInsertionPoint(0, 0)] + scheme.region_outputs = [RegionOutputInsertionPoint(None, 0, 0)] + scheme.latency_ms = 12.5 + scheme.error = False + + data = scheme.to_dict() + restored = InsertionScheme.from_dict(data) + + assert len(restored.node_inputs) == 1 + assert len(restored.child_region_inputs) == 1 + assert len(restored.region_outputs) == 1 + assert restored.latency_ms == 12.5 + assert not restored.error + print("✓ Full scheme serialization") + + def test_serialization_with_error(self): + """Test serialization with error flag.""" + scheme = InsertionScheme() + scheme.error = True + scheme.latency_ms = float("inf") + + data = scheme.to_dict() + restored = InsertionScheme.from_dict(data) + + assert restored.error + assert restored.latency_ms == float("inf") + print("✓ Scheme serialization with error") + + +def run_tests(): + """Run all insertion point and scheme tests.""" + print("=" * 70) + print("Autotuner Insertion Points & Schemes Test Suite") + print("=" * 70) + + # Create test suite + loader = unittest.TestLoader() + suite = unittest.TestSuite() + + # Add test classes + suite.addTests(loader.loadTestsFromTestCase(TestNodeInputInsertionPoint)) + suite.addTests(loader.loadTestsFromTestCase(TestRegionOutputInsertionPoint)) + suite.addTests(loader.loadTestsFromTestCase(TestChildRegionInputInsertionPoint)) + suite.addTests(loader.loadTestsFromTestCase(TestInsertionScheme)) + + # Run with verbose output + runner = unittest.TextTestRunner(verbosity=2) + result = runner.run(suite) + + # Summary + print("\n" + "=" * 70) + print("Test Summary") + print("=" * 70) + print(f"Tests run: {result.testsRun}") + print(f"Successes: {result.testsRun - len(result.failures) - len(result.errors)}") + print(f"Failures: {len(result.failures)}") + print(f"Errors: {len(result.errors)}") + + if result.wasSuccessful(): + print("\n✓ All insertion point and scheme tests passed!") + return 0 + else: + print("\n✗ Some tests failed") + return 1 + + +if __name__ == "__main__": + sys.exit(run_tests()) diff --git a/tests/unit/onnx/quantization/autotune/test_region.py b/tests/unit/onnx/quantization/autotune/test_region.py new file mode 100644 index 000000000..703d7dd94 --- /dev/null +++ b/tests/unit/onnx/quantization/autotune/test_region.py @@ -0,0 +1,219 @@ +#!/usr/bin/env python3 +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Tests for the Region class in the autotuner. + +Tests region creation, hierarchy, and boundary management. +""" + +import sys +import unittest + +from modelopt.onnx.quantization.autotune.common import Region, RegionType + + +class TestRegion(unittest.TestCase): + """Test Region class functionality.""" + + def test_leaf_region_creation(self): + """Test creating a LEAF region.""" + region = Region(region_id=1, level=0, region_type=RegionType.LEAF) + + assert region.get_id() == 1 + assert region.get_level() == 0 + assert region.get_type() == RegionType.LEAF + assert region.get_parent() is None + assert len(region.get_children()) == 0 + print("✓ LEAF region creation") + + def test_composite_region_creation(self): + """Test creating a COMPOSITE region.""" + region = Region(region_id=2, level=1, region_type=RegionType.COMPOSITE) + + assert region.get_id() == 2 + assert region.get_level() == 1 + assert region.get_type() == RegionType.COMPOSITE + print("✓ COMPOSITE region creation") + + def test_root_region_creation(self): + """Test creating a ROOT region.""" + region = Region(region_id=0, level=2, region_type=RegionType.ROOT) + + assert region.get_id() == 0 + assert region.get_level() == 2 + assert region.get_type() == RegionType.ROOT + print("✓ ROOT region creation") + + def test_parent_child_relationship(self): + """Test parent-child relationships.""" + parent = Region(region_id=1, level=1, region_type=RegionType.COMPOSITE) + child1 = Region(region_id=2, level=0, region_type=RegionType.LEAF) + child2 = Region(region_id=3, level=0, region_type=RegionType.LEAF) + + parent.add_child(child1) + parent.add_child(child2) + + assert len(parent.get_children()) == 2 + assert child1.get_parent() == parent + assert child2.get_parent() == parent + assert child1 in parent.get_children() + assert child2 in parent.get_children() + print("✓ Parent-child relationships") + + def test_add_nodes(self): + """Test adding nodes to a region.""" + region = Region(region_id=1, level=0, region_type=RegionType.LEAF) + + region.add_node(0) + region.add_node(1) + region.add_node(2) + + assert region.get_size() == 3 + assert 0 in region.get_nodes() + assert 1 in region.get_nodes() + assert 2 in region.get_nodes() + print("✓ Add nodes to region") + + def test_input_output_tensors(self): + """Test setting input and output tensors.""" + region = Region(region_id=1, level=0, region_type=RegionType.LEAF) + + # Directly assign to inputs/outputs attributes + region.inputs = ["input_tensor_1", "input_tensor_2"] + region.outputs = ["output_tensor_1"] + + assert len(region.get_inputs()) == 2 + assert len(region.get_outputs()) == 1 + assert "input_tensor_1" in region.get_inputs() + assert "output_tensor_1" in region.get_outputs() + print("✓ Input/output tensors") + + def test_region_size_recursive(self): + """Test recursive size calculation.""" + parent = Region(region_id=1, level=1, region_type=RegionType.COMPOSITE) + child1 = Region(region_id=2, level=0, region_type=RegionType.LEAF) + child2 = Region(region_id=3, level=0, region_type=RegionType.LEAF) + + # Add nodes to children + child1.add_node(0) + child1.add_node(1) + child2.add_node(2) + child2.add_node(3) + child2.add_node(4) + + # Add children to parent + parent.add_child(child1) + parent.add_child(child2) + + # Parent itself might have direct nodes + parent.add_node(5) + + # Recursive count should include all nodes + assert len(parent.get_all_nodes_recursive()) == 6 + print("✓ Recursive size calculation") + + def test_is_leaf(self): + """Test checking if region is LEAF type.""" + leaf = Region(region_id=1, level=0, region_type=RegionType.LEAF) + composite = Region(region_id=2, level=1, region_type=RegionType.COMPOSITE) + + assert leaf.get_type() == RegionType.LEAF + assert composite.get_type() != RegionType.LEAF + print("✓ Region LEAF type check") + + def test_is_composite(self): + """Test checking if region is COMPOSITE type.""" + leaf = Region(region_id=1, level=0, region_type=RegionType.LEAF) + composite = Region(region_id=2, level=1, region_type=RegionType.COMPOSITE) + + assert leaf.get_type() != RegionType.COMPOSITE + assert composite.get_type() == RegionType.COMPOSITE + print("✓ Region COMPOSITE type check") + + def test_hierarchical_structure(self): + """Test complex hierarchical structure.""" + root = Region(region_id=0, level=2, region_type=RegionType.ROOT) + composite1 = Region(region_id=1, level=1, region_type=RegionType.COMPOSITE) + composite2 = Region(region_id=2, level=1, region_type=RegionType.COMPOSITE) + leaf1 = Region(region_id=3, level=0, region_type=RegionType.LEAF) + leaf2 = Region(region_id=4, level=0, region_type=RegionType.LEAF) + leaf3 = Region(region_id=5, level=0, region_type=RegionType.LEAF) + + # Build hierarchy + root.add_child(composite1) + root.add_child(composite2) + composite1.add_child(leaf1) + composite1.add_child(leaf2) + composite2.add_child(leaf3) + + # Add some nodes + leaf1.add_node(0) + leaf2.add_node(1) + leaf3.add_node(2) + + # Verify structure + assert len(root.get_children()) == 2 + assert len(composite1.get_children()) == 2 + assert len(composite2.get_children()) == 1 + assert len(root.get_all_nodes_recursive()) == 3 + print("✓ Complex hierarchical structure") + + def test_remove_child(self): + """Test removing a child region.""" + parent = Region(region_id=1, level=1, region_type=RegionType.COMPOSITE) + child = Region(region_id=2, level=0, region_type=RegionType.LEAF) + + parent.add_child(child) + assert len(parent.get_children()) == 1 + + parent.remove_child(child) + assert len(parent.get_children()) == 0 + assert child.get_parent() is None + print("✓ Remove child region") + + +def run_tests(): + """Run all Region tests.""" + print("=" * 70) + print("Region Class Test Suite") + print("=" * 70) + + loader = unittest.TestLoader() + suite = unittest.TestSuite() + suite.addTests(loader.loadTestsFromTestCase(TestRegion)) + + runner = unittest.TextTestRunner(verbosity=2) + result = runner.run(suite) + + print("\n" + "=" * 70) + print("Test Summary") + print("=" * 70) + print(f"Tests run: {result.testsRun}") + print(f"Successes: {result.testsRun - len(result.failures) - len(result.errors)}") + print(f"Failures: {len(result.failures)}") + print(f"Errors: {len(result.errors)}") + + if result.wasSuccessful(): + print("\n✓ All Region tests passed!") + return 0 + else: + print("\n✗ Some tests failed") + return 1 + + +if __name__ == "__main__": + sys.exit(run_tests())