Skip to content

vllm.compilation.passes.fusion.allreduce_rms_fusion

Classes:

AiterAllreduceFusedAddRMSNormGroupQuantFP8Pattern

Bases: BasePattern, VllmPatternReplacement

fused_add variant of AiterAllreduceFusedRMSNormGroupQuantFP8Pattern.

Targets the dominant DSv3.2-style post-attention / post-MLP path: all_reduce -> fused_add_rms_norm -> group_fp8_quant. Returns the FP8 quant output, the residual carry-over, and the per-group scale.

Source code in vllm/compilation/passes/fusion/allreduce_rms_fusion.py
class AiterAllreduceFusedAddRMSNormGroupQuantFP8Pattern(
    BasePattern, VllmPatternReplacement
):
    """Residual-add (``fused_add``) sibling of
    ``AiterAllreduceFusedRMSNormGroupQuantFP8Pattern``.

    Matches the DSv3.2-style post-attention / post-MLP chain
    ``all_reduce -> fused_add_rms_norm -> group_fp8_quant`` and swaps it for
    the single fused AITER op. Pattern and replacement both yield, in order:
    the FP8 quant output, the per-group scale, and the residual carry-over.
    """

    def __init__(
        self,
        epsilon: float,
        dtype: torch.dtype,
        device: str | None,
        group_size: int = 128,
    ) -> None:
        """Store the norm/quant parameters and build the group-quant matcher.

        Args:
            epsilon: RMSNorm epsilon used by both pattern and replacement.
            dtype: Forwarded to ``BasePattern`` and kept on the instance.
            device: Forwarded to ``BasePattern``.
            group_size: Column width of one quantization group.
        """
        super().__init__(dtype, device)
        self.epsilon = epsilon
        self.dtype = dtype
        self.group_size = group_size
        self.FUSED_AR_RMS_QUANT_OP = (
            rocm_aiter_ops.get_fused_allreduce_rmsnorm_quant_per_group_op()
        )
        self.quant_dtype = current_platform.fp8_dtype()

        # Symmetric FP8 quant with float32 scales, one scale per
        # (1 x group_size) group.
        scale_desc = ScaleDesc(torch.float32, False, GroupShape(1, group_size))
        quant_key = QuantKey(
            dtype=self.quant_dtype,
            scale=scale_desc,
            symmetric=True,
        )
        self.quant_matcher = MatcherQuantFP8(quant_key, match_rocm_aiter=True)

    def get_inputs(self) -> list[torch.Tensor]:
        """Return example tensors ordered as ``[residual, input, weight]``."""
        width = self.group_size
        example_residual = self.empty(5, width)
        example_input = self.empty(5, width)
        example_weight = self.empty(width)
        return [example_residual, example_input, example_weight]

    @property
    def pattern(self):
        """Function describing the unfused subgraph to search for."""

        def _pattern(
            residual: torch.Tensor,
            input: torch.Tensor,
            weight: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
            reduced = tensor_model_parallel_all_reduce(input)
            normed, new_residual = vllm.ir.ops.fused_add_rms_norm(
                reduced, residual, weight, self.epsilon
            )
            fp8_out, group_scale = self.quant_matcher(normed)
            return fp8_out, group_scale, new_residual

        return _pattern

    @property
    def replacement(self):
        """Function emitting the fused AITER op in place of the match."""

        def _replacement(
            residual: torch.Tensor,
            input: torch.Tensor,
            weight: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
            fused = self.FUSED_AR_RMS_QUANT_OP(
                input_=input,
                residual=residual,
                weight=weight.to(input.dtype),
                epsilon=self.epsilon,
                group_size=self.group_size,
            )
            # Reorder the op's results to the pattern's output order:
            # quant_out, scale_out, residual_out.
            quant_out, scale_out, residual_out = fused[0], fused[2], fused[1]
            return quant_out, scale_out, residual_out

        return _replacement

AiterAllreduceFusedAddRMSNormGroupQuantWithIndexerPattern

Bases: BasePattern, VllmPatternReplacement

Indexer-fan-out variant of AiterAllreduceFusedAddRMSNormGroupQuantFP8Pattern.

Targets the DSv3.2 post-attention / post-MLP path where the post-AR normed activation has two consumers: a per-group FP8 quant for fused_qkv_a_proj and a bf16 rocm_unquantized_gemm for the indexer wk_weights_proj. The single-consumer pattern above cannot fire when this fan-out is present, so without this pattern the standalone FP8 quant kernel survives unfused (~535us / decode step on DSv3.2 MI355X TP4).

Lowers to rocm_aiter_fused_allreduce_rmsnorm_quant_per_group_with_bf16_norm (the emit_bf16=True variant of the AR+RMS+QUANT launcher, which returns FP8 quant + scales + bf16 normed activations in one kernel) and rewires the indexer GEMM onto the emitted bf16 norm output. The RMS output is also a graph output in DSv3.2's residual carry; it is returned as a pattern output so the matcher can substitute the bf16 norm in its place.

The trailing FP8 group-quant is matched via MatcherQuantFP8 (consistent with the sibling patterns above), which traces both QuantFP8.forward_hip and forward_native paths and so matches whichever op the call site lowers to (vllm.triton_per_token_group_quant_fp8 or vllm.rocm_aiter_group_fp8_quant).

Source code in vllm/compilation/passes/fusion/allreduce_rms_fusion.py
class AiterAllreduceFusedAddRMSNormGroupQuantWithIndexerPattern(
    BasePattern, VllmPatternReplacement
):
    """``AiterAllreduceFusedAddRMSNormGroupQuantFP8Pattern`` with an indexer
    fan-out.

    Covers the DSv3.2 post-attention / post-MLP path in which the post-AR
    normed activation feeds two consumers at once: the per-group FP8 quant
    for ``fused_qkv_a_proj`` and a bf16 ``rocm_unquantized_gemm`` for the
    indexer ``wk_weights_proj``. The single-consumer pattern cannot fire when
    that fan-out exists, which would leave the standalone FP8 quant kernel
    unfused (~535us / decode step on DSv3.2 MI355X TP4).

    The replacement lowers to
    ``rocm_aiter_fused_allreduce_rmsnorm_quant_per_group_with_bf16_norm``
    (the ``emit_bf16=True`` variant of the AR+RMS+QUANT launcher, returning
    FP8 quant + scales + bf16 normed activations from one kernel) and points
    the indexer GEMM at the emitted bf16 norm. Because the RMS output is also
    a graph output in DSv3.2's residual carry, it is exposed as a pattern
    output so the matcher can substitute the bf16 norm for it.

    The trailing FP8 group-quant is matched through ``MatcherQuantFP8``, as
    in the sibling patterns; it traces both ``QuantFP8.forward_hip`` and
    ``forward_native`` and therefore matches whichever op the call site
    lowers to (``vllm.triton_per_token_group_quant_fp8`` or
    ``vllm.rocm_aiter_group_fp8_quant``).
    """

    def __init__(
        self,
        epsilon: float,
        dtype: torch.dtype,
        device: str | None,
        group_size: int = 128,
    ) -> None:
        """Store the norm/quant parameters and build the group-quant matcher.

        Args:
            epsilon: RMSNorm epsilon used by both pattern and replacement.
            dtype: Forwarded to ``BasePattern`` and kept on the instance.
            device: Forwarded to ``BasePattern``.
            group_size: Column width of one quantization group.
        """
        super().__init__(dtype, device)
        self.epsilon = epsilon
        self.dtype = dtype
        self.group_size = group_size
        self.FUSED_AR_RMS_QUANT_BF16_OP = (
            rocm_aiter_ops.get_fused_allreduce_rmsnorm_quant_per_group_with_bf16_norm_op()  # noqa: E501
        )
        self.quant_dtype = current_platform.fp8_dtype()

        # Symmetric FP8 quant with float32 scales, one scale per
        # (1 x group_size) group.
        scale_desc = ScaleDesc(torch.float32, False, GroupShape(1, group_size))
        quant_key = QuantKey(
            dtype=self.quant_dtype,
            scale=scale_desc,
            symmetric=True,
        )
        self.quant_matcher = MatcherQuantFP8(quant_key, match_rocm_aiter=True)

    def get_inputs(self) -> list[torch.Tensor]:
        """Return example tensors ordered as
        ``[residual, input, norm_weight, indexer_weight]``.
        """
        hidden = self.group_size
        indexer_rows = 8
        example_residual = self.empty(5, hidden)
        example_input = self.empty(5, hidden)
        example_norm_weight = self.empty(hidden)
        example_indexer_weight = self.empty(indexer_rows, hidden)
        return [
            example_residual,
            example_input,
            example_norm_weight,
            example_indexer_weight,
        ]

    @property
    def pattern(self):
        """Function describing the unfused subgraph, including the fan-out."""
        epsilon = self.epsilon

        def _pattern(
            residual: torch.Tensor,
            input_: torch.Tensor,
            norm_weight: torch.Tensor,
            indexer_weight: torch.Tensor,
        ) -> tuple[
            torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor
        ]:
            reduced = tensor_model_parallel_all_reduce(input_)
            normed, new_residual = vllm.ir.ops.fused_add_rms_norm(
                reduced, residual, norm_weight, epsilon
            )
            # The normed activation has two consumers: quant, then the GEMM.
            fp8_out, group_scale = self.quant_matcher(normed)
            indexer_out = torch.ops.vllm.rocm_unquantized_gemm(
                normed, indexer_weight
            )
            return fp8_out, group_scale, new_residual, indexer_out, normed

        return _pattern

    @property
    def replacement(self):
        """Function emitting the fused op plus the rewired indexer GEMM."""
        group_size = self.group_size
        epsilon = self.epsilon

        def _replacement(
            residual: torch.Tensor,
            input_: torch.Tensor,
            norm_weight: torch.Tensor,
            indexer_weight: torch.Tensor,
        ) -> tuple[
            torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor
        ]:
            fused = self.FUSED_AR_RMS_QUANT_BF16_OP(
                input_=input_,
                residual=residual,
                weight=norm_weight.to(input_.dtype),
                epsilon=epsilon,
                group_size=group_size,
            )
            # Op result order: quant, residual, scale, bf16 norm.
            quant_out = fused[0]
            residual_out = fused[1]
            scale_out = fused[2]
            bf16_norm = fused[3]
            # The indexer GEMM now reads the bf16 norm emitted by the fused op.
            indexer_out = torch.ops.vllm.rocm_unquantized_gemm(
                bf16_norm, indexer_weight
            )
            return quant_out, scale_out, residual_out, indexer_out, bf16_norm

        return _replacement

AiterAllreduceFusedRMSNormGroupQuantFP8Pattern

Bases: BasePattern, VllmPatternReplacement

Fuse AllReduce + RMSNorm + per-group FP8 quant into a single AITER custom op.

Matches the AR-side analogue of AiterRMSFp8GroupQuantPattern in rocm_aiter_fusion.py: the chain all_reduce -> rms_norm -> group_fp8_quant is rewritten into the single op rocm_aiter_fused_allreduce_rmsnorm_quant_per_group.

Without this pattern, RocmAiterAllReduceFusionPass would fuse the all_reduce + rms_norm half (PR #41825 wires that), but the trailing rocm_aiter_group_fp8_quant would still launch as a separate kernel. That standalone quant accounts for ~535us / decode step on DSv3.2 MI355X TP4 -- this pattern eliminates it by absorbing the quant into the AR epilogue. Group size 128 matches the FP8 block-scaled MM kernel used by DSv3.2's linear weights.

Source code in vllm/compilation/passes/fusion/allreduce_rms_fusion.py
class AiterAllreduceFusedRMSNormGroupQuantFP8Pattern(
    BasePattern, VllmPatternReplacement
):
    """Collapse AllReduce + RMSNorm + per-group FP8 quant into one AITER
    custom op.

    This is the AR-side analogue of ``AiterRMSFp8GroupQuantPattern`` in
    ``rocm_aiter_fusion.py``: the chain
    ``all_reduce -> rms_norm -> group_fp8_quant`` is rewritten into
    ``rocm_aiter_fused_allreduce_rmsnorm_quant_per_group``.

    Without it, ``RocmAiterAllReduceFusionPass`` would fuse only the
    ``all_reduce + rms_norm`` half (PR #41825 wires that) and the trailing
    ``rocm_aiter_group_fp8_quant`` would still launch as its own kernel.
    That standalone quant costs ~535us / decode step on DSv3.2 MI355X TP4;
    absorbing it into the AR epilogue removes it. Group size 128 matches the
    FP8 block-scaled MM kernel used by DSv3.2's linear weights.
    """

    def __init__(
        self,
        epsilon: float,
        dtype: torch.dtype,
        device: str | None,
        group_size: int = 128,
    ) -> None:
        """Store the norm/quant parameters and build the group-quant matcher.

        Args:
            epsilon: RMSNorm epsilon used by both pattern and replacement.
            dtype: Forwarded to ``BasePattern`` and kept on the instance.
            device: Forwarded to ``BasePattern``.
            group_size: Column width of one quantization group.
        """
        super().__init__(dtype, device)
        self.epsilon = epsilon
        self.dtype = dtype
        self.group_size = group_size
        self.FUSED_AR_RMS_QUANT_OP = (
            rocm_aiter_ops.get_fused_allreduce_rmsnorm_quant_per_group_op()
        )
        self.quant_dtype = current_platform.fp8_dtype()

        # Symmetric FP8 quant with float32 scales, one scale per
        # (1 x group_size) group.
        scale_desc = ScaleDesc(torch.float32, False, GroupShape(1, group_size))
        quant_key = QuantKey(
            dtype=self.quant_dtype,
            scale=scale_desc,
            symmetric=True,
        )
        self.quant_matcher = MatcherQuantFP8(quant_key, match_rocm_aiter=True)

    def get_inputs(self) -> list[torch.Tensor]:
        """Return example tensors ordered as ``[input, weight]``.

        The hidden dim is set to ``group_size`` (a group_size multiple) so
        the group quant matcher's example trace is well-defined.
        """
        width = self.group_size
        example_input = self.empty(5, width)
        example_weight = self.empty(width)
        return [example_input, example_weight]

    @property
    def pattern(self):
        """Function describing the unfused subgraph to search for."""

        def _pattern(
            input: torch.Tensor, weight: torch.Tensor
        ) -> tuple[torch.Tensor, torch.Tensor]:
            reduced = tensor_model_parallel_all_reduce(input)
            normed = vllm.ir.ops.rms_norm(reduced, weight, self.epsilon)
            fp8_out, group_scale = self.quant_matcher(normed)
            return fp8_out, group_scale

        return _pattern

    @property
    def replacement(self):
        """Function emitting the fused AITER op in place of the match."""

        def _replacement(
            input: torch.Tensor, weight: torch.Tensor
        ) -> tuple[torch.Tensor, torch.Tensor]:
            # The fused op takes a residual argument; an all-zero tensor is
            # passed on this no-add path.
            zero_residual = torch.zeros_like(input)
            fused = self.FUSED_AR_RMS_QUANT_OP(
                input_=input,
                residual=zero_residual,
                weight=weight.to(input.dtype),
                epsilon=self.epsilon,
                group_size=self.group_size,
            )
            # Keep quant_out and scale_out only; the residual result is
            # unused on the no-add path, mirroring how
            # AiterAllreduceFusedRMSNormPattern drops the residual output.
            quant_out, scale_out = fused[0], fused[2]
            return quant_out, scale_out

        return _replacement

AllReduceFusedAddGemmaRMSNormPattern

Bases: BasePattern

Gemma-style variant of AllReduceFusedAddRMSNormPattern (with residual).

Source code in vllm/compilation/passes/fusion/allreduce_rms_fusion.py
class AllReduceFusedAddGemmaRMSNormPattern(BasePattern):
    """Gemma-style flavour of AllReduceFusedAddRMSNormPattern (with residual).

    The searched graph normalizes with ``weight.float() + 1.0``; the
    replacement passes the raw weight together with ``weight_bias=1.0`` to
    the fused FlashInfer op.
    """

    def __init__(
        self,
        epsilon: float,
        dtype: torch.dtype,
        device: str | None,
        allreduce_params: FlashInferFusedAllReduceParams,
    ) -> None:
        """Store the RMSNorm epsilon and the fused all-reduce parameters."""
        super().__init__(dtype, device)
        self.epsilon = epsilon
        self.allreduce_params = allreduce_params

    def get_inputs(self) -> list[torch.Tensor]:
        """Return example tensors ordered as ``[residual, input, weight]``."""
        example_input = self.empty(5, 16)
        example_residual = self.empty(5, 16)
        example_weight = self.empty(16)
        # The input example is explicitly cast to the pattern dtype.
        return [example_residual, example_input.to(self.dtype), example_weight]

    def register(self, pm_pass: PatternMatcherPass) -> None:
        """Register the two-output rewrite and its first-output-only form."""

        def pattern(
            residual: torch.Tensor, input: torch.Tensor, weight: torch.Tensor
        ) -> tuple[torch.Tensor, torch.Tensor]:
            reduced = tensor_model_parallel_all_reduce(input)
            # Gemma-style gain: weight upcast to fp32 and offset by 1.0.
            gemma_gain = weight.float() + 1.0
            normed, new_residual = vllm.ir.ops.fused_add_rms_norm(
                reduced, residual, gemma_gain, self.epsilon
            )
            return normed, new_residual

        def replacement(
            residual: torch.Tensor, input: torch.Tensor, weight: torch.Tensor
        ) -> tuple[torch.Tensor, torch.Tensor]:
            assert flashinfer_comm is not None, "FlashInfer must be enabled"
            fused = auto_functionalized(
                flashinfer_trtllm_fused_allreduce_norm,
                allreduce_in=input,
                residual=residual,
                norm_out=None,
                quant_out=None,
                scale_out=None,
                rms_gamma=weight,
                rms_eps=self.epsilon,
                pattern_code=flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNorm,
                weight_bias=1.0,
                **self.allreduce_params.get_trtllm_fused_allreduce_kwargs(),
            )
            # Entries 1 and 2 stand in for the pattern's (norm, residual).
            return fused[1], fused[2]

        # Full form: both the normed output and the residual are consumed.
        pm.register_replacement(
            pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
        )

        # Reduced form: same graph, exposing only the first output.
        head_only = lambda fn: lambda a, b, c: fn(a, b, c)[0]

        pm.register_replacement(
            head_only(pattern),  # type: ignore[no-untyped-call]
            head_only(replacement),  # type: ignore[no-untyped-call]
            self.get_inputs(),
            pm.fwd_only,
            pm_pass,
        )

AllReduceFusedAddRMSNormPattern

Bases: BasePattern

This pattern replaces the allreduce + rms norm (with residual) with the fused FlashInfer implementation. Applies to o_proj + rmsnorm after attn and mlp + rmsnorm before attn.

Source code in vllm/compilation/passes/fusion/allreduce_rms_fusion.py
class AllReduceFusedAddRMSNormPattern(BasePattern):
    """
    Rewrites allreduce + rms norm (with residual) into the fused
    FlashInfer implementation.
    Applies to o_proj + rmsnorm after attn and mlp + rmsnorm before attn.
    """

    def __init__(
        self,
        epsilon: float,
        dtype: torch.dtype,
        device: str | None,
        allreduce_params: FlashInferFusedAllReduceParams,
    ) -> None:
        """Store the RMSNorm epsilon and the fused all-reduce parameters."""
        super().__init__(dtype, device)
        self.epsilon = epsilon
        self.allreduce_params = allreduce_params

    def get_inputs(self) -> list[torch.Tensor]:
        """Return example tensors ordered as ``[residual, input, weight]``."""
        example_input = self.empty(5, 16)
        example_residual = self.empty(5, 16)
        example_weight = self.empty(16)

        # input goes through allreduce first, always 16-bit
        return [example_residual, example_input.to(self.dtype), example_weight]

    def register(self, pm_pass: PatternMatcherPass) -> None:
        """Register the two-output rewrite and its first-output-only form."""

        def pattern(
            residual: torch.Tensor, input: torch.Tensor, weight: torch.Tensor
        ) -> tuple[torch.Tensor, torch.Tensor]:
            reduced = tensor_model_parallel_all_reduce(input)
            normed, new_residual = vllm.ir.ops.fused_add_rms_norm(
                reduced, residual, weight, self.epsilon
            )
            return normed, new_residual

        def replacement(
            residual: torch.Tensor, input: torch.Tensor, weight: torch.Tensor
        ) -> tuple[torch.Tensor, torch.Tensor]:
            assert flashinfer_comm is not None, "FlashInfer must be enabled"
            fused = auto_functionalized(
                flashinfer_trtllm_fused_allreduce_norm,
                allreduce_in=input,
                residual=residual,
                norm_out=None,
                quant_out=None,
                scale_out=None,
                rms_gamma=weight,
                rms_eps=self.epsilon,
                pattern_code=flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNorm,
                **self.allreduce_params.get_trtllm_fused_allreduce_kwargs(),
            )
            # Entries 1 and 2 are allreduce_in and residual.
            return fused[1], fused[2]

        # extra_check routes a Gemma fp32 gamma to
        # AllReduceFusedAddGemmaRMSNormPattern.
        pm.register_replacement(
            pattern,
            replacement,
            self.get_inputs(),
            pm.fwd_only,
            pm_pass,
            extra_check=_norm_input_weight_dtype_match,
        )

        # Second registration: identical graph, but only the normed output is
        # returned (helpful at the end of the graph, where the residual is
        # not used again).
        head_only = lambda fn: lambda a, b, c: fn(a, b, c)[0]

        pm.register_replacement(
            head_only(pattern),  # type: ignore[no-untyped-call]
            head_only(replacement),  # type: ignore[no-untyped-call]
            self.get_inputs(),
            pm.fwd_only,
            pm_pass,
            extra_check=_norm_input_weight_dtype_match,
        )

AllReduceFusedAddRMSNormStaticQuantFP8Pattern

Bases: BasePattern

This pattern replaces the allreduce + rms norm (with residual) + static fp8 quant with the fused FlashInfer implementation. Applies to o_proj + rmsnorm after attn + quant and mlp + rmsnorm + quant before attn.

Source code in vllm/compilation/passes/fusion/allreduce_rms_fusion.py
class AllReduceFusedAddRMSNormStaticQuantFP8Pattern(BasePattern):
    """
    Rewrites allreduce + rms norm (with residual) + static fp8 quant into
    the fused FlashInfer implementation.
    Applies to o_proj + rmsnorm after attn + quant and
    mlp + rmsnorm + quant before attn.
    """

    def __init__(
        self,
        epsilon: float,
        dtype: torch.dtype,
        device: str | None,
        allreduce_params: FlashInferFusedAllReduceParams,
    ) -> None:
        """Store epsilon and all-reduce params; set up the static FP8 matcher."""
        super().__init__(dtype, device)
        self.epsilon = epsilon
        self.allreduce_params = allreduce_params
        self.quant_dtype = torch.float8_e4m3fn

        self.quant_matcher = MatcherQuantFP8(kFp8StaticTensorSym)

    def get_inputs(self) -> list[torch.Tensor]:
        """Return example tensors ordered as
        ``[residual, input, weight, scale]``.
        """
        example_input = self.empty(5, 16)
        example_residual = self.empty(5, 16)
        example_weight = self.empty(16)
        # Only the scale example from the quant matcher is needed here.
        _, example_scale = self.quant_matcher.inputs()

        # input goes through allreduce first, always 16-bit
        return [
            example_residual,
            example_input.to(self.dtype),
            example_weight,
            example_scale,
        ]

    def register(self, pm_pass: PatternMatcherPass) -> None:
        """Register the rewrite on ``pm_pass``."""

        def pattern(
            residual: torch.Tensor,
            input: torch.Tensor,
            weight: torch.Tensor,
            scale: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            reduced = tensor_model_parallel_all_reduce(input)
            normed, new_residual = vllm.ir.ops.fused_add_rms_norm(
                reduced, residual, weight, self.epsilon
            )
            fp8_out, _ = self.quant_matcher(normed, scale)

            return fp8_out, new_residual

        def replacement(
            residual: torch.Tensor,
            input: torch.Tensor,
            weight: torch.Tensor,
            scale: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            # Output buffer for the quantized result, shaped like the input.
            quant_buffer = torch.empty_like(input, dtype=self.quant_dtype)
            assert flashinfer_comm is not None, "FlashInfer must be enabled"
            fused = auto_functionalized(
                flashinfer_trtllm_fused_allreduce_norm,
                allreduce_in=input,
                residual=residual,
                norm_out=None,
                quant_out=quant_buffer,
                scale_out=None,
                rms_gamma=weight,
                rms_eps=self.epsilon,
                # We don't use norm_out afterwards
                pattern_code=(
                    flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNormFP8Quant
                ),
                scale_factor=scale,
                **self.allreduce_params.get_trtllm_fused_allreduce_kwargs(),
            )
            # Entries 4 and 2 are quant_out and the rms-norm residual.
            return fused[4], fused[2]

        pm.register_replacement(
            pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
        )

AllReduceFusedAddRMSNormStaticQuantNVFP4Pattern

Bases: BasePattern

This pattern replaces the allreduce + rms norm (with residual) + static nvfp4 quant with the fused FlashInfer implementation. Applies to o_proj + rmsnorm after attn + quant and mlp + rmsnorm + quant before attn.

Source code in vllm/compilation/passes/fusion/allreduce_rms_fusion.py
class AllReduceFusedAddRMSNormStaticQuantNVFP4Pattern(BasePattern):
    """
    Rewrites allreduce + rms norm (with residual) + static nvfp4 quant into
    the fused FlashInfer implementation.
    Applies to o_proj + rmsnorm after attn + quant and
    mlp + rmsnorm + quant before attn.
    """

    def __init__(
        self,
        epsilon: float,
        dtype: torch.dtype,
        device: str | None,
        allreduce_params: FlashInferFusedAllReduceParams,
    ) -> None:
        """Store the RMSNorm epsilon and the fused all-reduce parameters."""
        super().__init__(dtype, device)
        self.epsilon = epsilon
        self.allreduce_params = allreduce_params

    def get_inputs(self) -> list[torch.Tensor]:
        """Return example tensors in the argument order of ``pattern``:
        ``[quant_result, residual, input, output_scale, weight,
        input_global_scale]``.
        """
        example_input = torch.empty(
            [16, 16], device=self.device, dtype=self.dtype
        )

        example_residual = torch.empty(
            [16, 16], device=self.device, dtype=self.dtype
        )
        example_weight = torch.empty(
            [16, 16], device=self.device, dtype=self.dtype
        )
        example_quant = torch.empty(
            (16, 8), device=self.device, dtype=torch.uint8
        )
        example_global_scale = torch.empty(
            [1, 1], device=self.device, dtype=torch.float32
        )
        example_out_scale = torch.empty(
            [128, 4], device=self.device, dtype=torch.int32
        )

        return [
            example_quant,
            example_residual,
            example_input,
            example_out_scale,
            example_weight,
            example_global_scale,
        ]

    def register(self, pm_pass: PatternMatcherPass) -> None:
        """Register the rewrite on ``pm_pass``."""

        def pattern(
            quant_result: torch.Tensor,
            residual: torch.Tensor,
            input: torch.Tensor,
            output_scale: torch.Tensor,
            weight: torch.Tensor,
            input_global_scale: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
            reduced = tensor_model_parallel_all_reduce(input)
            normed, new_residual = vllm.ir.ops.fused_add_rms_norm(
                reduced, residual, weight, self.epsilon
            )
            fp4_quant = auto_functionalized(
                STATIC_FP4_QUANT_OP,
                input=normed,
                input_scale=input_global_scale,
                is_sf_swizzled_layout=True,
                output=quant_result,
                output_scale=output_scale,
            )

            # quant_out, residual_out, output_scale
            return fp4_quant[1], new_residual, fp4_quant[2]

        def replacement(
            quant_result: torch.Tensor,
            residual: torch.Tensor,
            input: torch.Tensor,
            output_scale: torch.Tensor,
            weight: torch.Tensor,
            input_global_scale: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
            assert flashinfer_comm is not None, "FlashInfer must be enabled"
            fused = auto_functionalized(
                flashinfer_trtllm_fused_allreduce_norm,
                allreduce_in=input,
                residual=residual,
                norm_out=None,
                quant_out=quant_result,
                scale_out=output_scale,
                rms_gamma=weight,
                rms_eps=self.epsilon,
                # We don't use norm_out afterwards
                pattern_code=(
                    flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNormFP4Quant
                ),
                scale_factor=input_global_scale,
                **self.allreduce_params.get_trtllm_fused_allreduce_kwargs(),
            )
            # Entries 4, 2 and 5 are quant_out, the rms-norm residual and
            # output_scale.
            return fused[4], fused[2], fused[5]

        pm.register_replacement(
            pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
        )

AllReduceFusedRMSNormStaticQuantFP8Pattern

Bases: BasePattern

This pattern replaces the allreduce + rms norm (without residual) + static fp8 quant with the fused FlashInfer implementation. Applies to allreduce + rmsnorm + quant before attn in the first Transformer block.

Source code in vllm/compilation/passes/fusion/allreduce_rms_fusion.py
class AllReduceFusedRMSNormStaticQuantFP8Pattern(BasePattern):
    """
    Rewrites allreduce + rms norm (without residual) + static fp8 quant into
    the fused FlashInfer implementation.
    Applies to allreduce + rmsnorm + quant before attn
    in the first Transformer block.
    """

    def __init__(
        self,
        epsilon: float,
        dtype: torch.dtype,
        device: str | None,
        allreduce_params: FlashInferFusedAllReduceParams,
    ) -> None:
        """Store epsilon and all-reduce params; set up the static FP8 matcher."""
        super().__init__(dtype, device)
        self.epsilon = epsilon
        self.allreduce_params = allreduce_params
        self.quant_dtype = torch.float8_e4m3fn
        self.quant_matcher = MatcherQuantFP8(kFp8StaticTensorSym)

    def get_inputs(self) -> list[torch.Tensor]:
        """Return example tensors ordered as ``[input, weight, scale]``."""
        # Only the scale example from the quant matcher is needed here.
        _, example_scale = self.quant_matcher.inputs()

        example_input = self.empty(5, 16)
        example_weight = self.empty(16)
        return [example_input, example_weight, example_scale]

    def register(self, pm_pass: PatternMatcherPass) -> None:
        """Register the rewrite on ``pm_pass``."""

        def pattern(
            input: torch.Tensor,
            weight: torch.Tensor,
            scale: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            reduced = tensor_model_parallel_all_reduce(input)
            normed = vllm.ir.ops.rms_norm(reduced, weight, self.epsilon)
            fp8_out, _ = self.quant_matcher(normed, scale)
            # The all-reduce result is an output alongside the quant.
            return fp8_out, reduced

        def replacement(
            input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor
        ) -> tuple[torch.Tensor, torch.Tensor]:
            # There is no residual on this path, so zeros are supplied.
            zero_residual = torch.zeros_like(input)
            norm_buffer = torch.empty_like(input)
            quant_buffer = torch.empty_like(input, dtype=self.quant_dtype)
            assert flashinfer_comm is not None, "FlashInfer must be enabled"
            fused = auto_functionalized(
                flashinfer_trtllm_fused_allreduce_norm,
                allreduce_in=input,
                residual=zero_residual,
                norm_out=norm_buffer,
                quant_out=quant_buffer,
                scale_out=None,
                rms_gamma=weight,
                rms_eps=self.epsilon,
                # We don't use norm_out afterwards
                pattern_code=(
                    flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNormFP8Quant
                ),
                scale_factor=scale,
                **self.allreduce_params.get_trtllm_fused_allreduce_kwargs(),
            )

            # Entries 4 and 1 are quant_out and the allreduce output.
            return fused[4], fused[1]

        pm.register_replacement(
            pattern,
            replacement,
            self.get_inputs(),
            pm.fwd_only,
            pm_pass,
            extra_check=_rms_input_weight_dtype_match,
        )

AllReduceFusedRMSNormStaticQuantNVFP4Pattern

Bases: BasePattern

This pattern replaces the allreduce + rms norm (without residual) + static nvfp4 quant with the fused FlashInfer implementation. Applies to allreduce + rmsnorm + quant before attn in the first Transformer block.

Source code in vllm/compilation/passes/fusion/allreduce_rms_fusion.py
class AllReduceFusedRMSNormStaticQuantNVFP4Pattern(BasePattern):
    """
    Rewrites allreduce + rms norm (without residual) + static nvfp4 quant
    into the fused FlashInfer implementation.
    Applies to allreduce + rmsnorm + quant before attn
    in the first Transformer block.
    """

    def __init__(
        self,
        epsilon: float,
        dtype: torch.dtype,
        device: str | None,
        allreduce_params: FlashInferFusedAllReduceParams,
    ) -> None:
        """Store the RMSNorm epsilon and the fused all-reduce parameters."""
        super().__init__(dtype, device)
        self.epsilon = epsilon
        self.allreduce_params = allreduce_params

    def get_inputs(self) -> list[torch.Tensor]:
        """Return example tensors in the argument order of ``pattern``:
        ``[input, quant_result, weight, input_global_scale, output_scale]``.
        """
        example_input = torch.empty(
            [1, 16, 16], device=self.device, dtype=self.dtype
        )
        example_quant = torch.empty(
            (16, 8), device=self.device, dtype=torch.uint8
        )
        example_global_scale = torch.empty(
            [1, 1], device=self.device, dtype=torch.float32
        )
        example_weight = torch.empty(
            [16], device=self.device, dtype=self.dtype
        )
        example_out_scale = torch.empty(
            [128, 4], device=self.device, dtype=torch.int32
        )

        return [
            example_input,
            example_quant,
            example_weight,
            example_global_scale,
            example_out_scale,
        ]

    def register(self, pm_pass: PatternMatcherPass) -> None:
        """Register the rewrite on ``pm_pass``."""

        def pattern(
            input: torch.Tensor,
            quant_result: torch.Tensor,
            weight: torch.Tensor,
            input_global_scale: torch.Tensor,
            output_scale: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
            reduced = tensor_model_parallel_all_reduce(input)
            normed = vllm.ir.ops.rms_norm(reduced, weight, self.epsilon)
            fp4_quant = auto_functionalized(
                STATIC_FP4_QUANT_OP,
                input=normed,
                input_scale=input_global_scale,
                is_sf_swizzled_layout=True,
                output=quant_result,
                output_scale=output_scale,
            )

            # quant_out, allreduce_output, output_scale
            return fp4_quant[1], reduced, fp4_quant[2]

        def replacement(
            input: torch.Tensor,
            quant_result: torch.Tensor,
            weight: torch.Tensor,
            input_global_scale: torch.Tensor,
            output_scale: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
            # There is no residual on this path, so zeros are supplied.
            zero_residual = torch.zeros_like(input)
            norm_buffer = torch.empty_like(input)
            assert flashinfer_comm is not None, "FlashInfer must be enabled"
            fused = auto_functionalized(
                flashinfer_trtllm_fused_allreduce_norm,
                allreduce_in=input,
                residual=zero_residual,
                norm_out=norm_buffer,
                quant_out=quant_result,
                scale_out=output_scale,
                rms_gamma=weight,
                rms_eps=self.epsilon,
                # We don't use norm_out afterwards
                pattern_code=(
                    flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNormFP4Quant
                ),
                scale_factor=input_global_scale,
                **self.allreduce_params.get_trtllm_fused_allreduce_kwargs(),
            )

            # Entries 4, 1 and 5 are quant_out, the allreduce output and
            # output_scale.
            return fused[4], fused[1], fused[5]

        pm.register_replacement(
            pattern,
            replacement,
            self.get_inputs(),
            pm.fwd_only,
            pm_pass,
            extra_check=_rms_input_weight_dtype_match,
        )

AllReduceGemmaRMSNormPattern

Bases: BasePattern

Gemma-style variant of AllReduceRMSNormPattern (no residual).

Source code in vllm/compilation/passes/fusion/allreduce_rms_fusion.py
class AllReduceGemmaRMSNormPattern(BasePattern):
    """Gemma-style variant of AllReduceRMSNormPattern (no residual).

    Matches ``tensor_model_parallel_all_reduce`` followed by
    ``vllm.ir.ops.rms_norm`` whose gamma is ``weight.float() + 1.0`` and
    rewrites the pair into one ``flashinfer_trtllm_fused_allreduce_norm``
    call that is given ``weight_bias=1.0``.
    """

    def __init__(
        self,
        epsilon: float,
        dtype: torch.dtype,
        device: str | None,
        allreduce_params: FlashInferFusedAllReduceParams,
    ) -> None:
        # dtype/device are handed to BasePattern; this class itself only keeps
        # the norm epsilon and the fused all-reduce launch parameters.
        super().__init__(dtype, device)
        self.epsilon = epsilon
        self.allreduce_params = allreduce_params

    def get_inputs(self) -> list[torch.Tensor]:
        """Return the example tensors used to trace the pattern."""
        # input (5 x 16), weight (16)
        return [self.empty(5, 16), self.empty(16)]

    def register(self, pm_pass: PatternMatcherPass) -> None:
        """Register the all-reduce + Gemma RMS norm rewrite on ``pm_pass``."""

        def pattern(
            input: torch.Tensor, weight: torch.Tensor
        ) -> tuple[torch.Tensor, torch.Tensor]:
            # Unfused form: all-reduce, then RMS norm with gamma = weight + 1
            # computed after casting the weight to float.
            allreduce_output = tensor_model_parallel_all_reduce(input)
            rms = vllm.ir.ops.rms_norm(
                allreduce_output, weight.float() + 1.0, self.epsilon
            )
            return rms, allreduce_output

        def replacement(
            input: torch.Tensor, weight: torch.Tensor
        ) -> tuple[torch.Tensor, torch.Tensor]:
            # The op is invoked with the kARResidualRMSNorm pattern code while
            # the matched graph has no residual, so an all-zero residual is
            # supplied.
            residual = torch.zeros_like(input)
            # Buffer passed to the op as norm_out.
            rms_result = torch.empty_like(input)
            assert flashinfer_comm is not None, "FlashInfer must be enabled"
            allreduce = auto_functionalized(
                flashinfer_trtllm_fused_allreduce_norm,
                allreduce_in=input,
                residual=residual,
                norm_out=rms_result,
                # No quantization outputs are requested for this pattern.
                quant_out=None,
                scale_out=None,
                rms_gamma=weight,
                rms_eps=self.epsilon,
                pattern_code=flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNorm,
                # NOTE(review): presumably stands in for the `+ 1.0` applied
                # to the weight in `pattern` -- confirm against the op.
                weight_bias=1.0,
                **self.allreduce_params.get_trtllm_fused_allreduce_kwargs(),
            )
            # NOTE(review): [3] is presumably norm_out and [1] allreduce_in
            # (mutated arguments in keyword order) -- confirm.
            return allreduce[3], allreduce[1]

        # Unlike some other registrations, no extra_check is passed here.
        pm.register_replacement(
            pattern,
            replacement,
            self.get_inputs(),
            pm.fwd_only,
            pm_pass,
        )

AllReduceRMSNormPattern

Bases: BasePattern

This pattern replaces the all-reduce + RMS norm sequence (without residual) with a fused FlashInfer implementation. It applies to the all-reduce + RMS norm that precedes attention in the first Transformer block.

Source code in vllm/compilation/passes/fusion/allreduce_rms_fusion.py
class AllReduceRMSNormPattern(BasePattern):
    """
    Replace all-reduce + RMS norm (without residual) with the fused
    FlashInfer implementation.

    Applies to allreduce + rmsnorm before attn in the first Transformer block.
    The match is registered with ``_rms_input_weight_dtype_match`` as an
    extra check.
    """

    def __init__(
        self,
        epsilon: float,
        dtype: torch.dtype,
        device: str | None,
        allreduce_params: FlashInferFusedAllReduceParams,
    ) -> None:
        # dtype/device are handed to BasePattern; this class itself only keeps
        # the norm epsilon and the fused all-reduce launch parameters.
        super().__init__(dtype, device)
        self.epsilon = epsilon
        self.allreduce_params = allreduce_params

    def get_inputs(self) -> list[torch.Tensor]:
        """Return the example tensors used to trace the pattern."""
        # input, weight
        return [self.empty(5, 16), self.empty(16)]

    def register(self, pm_pass: PatternMatcherPass) -> None:
        """Register the all-reduce + RMS norm rewrite on ``pm_pass``."""

        def pattern(
            input: torch.Tensor, weight: torch.Tensor
        ) -> tuple[torch.Tensor, torch.Tensor]:
            # Unfused form: all-reduce followed by RMS norm on its output.
            allreduce_output = tensor_model_parallel_all_reduce(input)
            rms = vllm.ir.ops.rms_norm(allreduce_output, weight, self.epsilon)

            return rms, allreduce_output

        def replacement(
            input: torch.Tensor, weight: torch.Tensor
        ) -> tuple[torch.Tensor, torch.Tensor]:
            # The op is invoked with the kARResidualRMSNorm pattern code while
            # the matched graph has no residual, so an all-zero residual is
            # supplied.
            residual = torch.zeros_like(input)
            # Buffer passed to the op as norm_out.
            rms_result = torch.empty_like(input)
            assert flashinfer_comm is not None, "FlashInfer must be enabled"
            allreduce = auto_functionalized(
                flashinfer_trtllm_fused_allreduce_norm,
                allreduce_in=input,
                residual=residual,
                norm_out=rms_result,
                # No quantization outputs are requested for this pattern.
                quant_out=None,
                scale_out=None,
                rms_gamma=weight,
                rms_eps=self.epsilon,
                pattern_code=flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNorm,
                **self.allreduce_params.get_trtllm_fused_allreduce_kwargs(),
            )
            # rms_result, allreduce_in
            return allreduce[3], allreduce[1]

        pm.register_replacement(
            pattern,
            replacement,
            self.get_inputs(),
            pm.fwd_only,
            pm_pass,
            # Callback given to the matcher to accept or reject a match.
            extra_check=_rms_input_weight_dtype_match,
        )

FlashInferFusedAllReduceParams

Parameters for FlashInfer fused allreduce operations.

Source code in vllm/compilation/passes/fusion/allreduce_rms_fusion.py
class FlashInferFusedAllReduceParams:
    """Parameters for FlashInfer fused allreduce operations."""

    def __init__(
        self,
        world_size: int,
        max_token_num: int = 1024,
    ) -> None:
        self.world_size = world_size
        self.launch_with_pdl = True
        self.fp32_acc = True
        self.max_token_num = max_token_num

    def get_trtllm_fused_allreduce_kwargs(self) -> dict[str, bool | int]:
        return {
            "world_size": self.world_size,
            "launch_with_pdl": self.launch_with_pdl,
            "fp32_acc": self.fp32_acc,
            "max_token_num": self.max_token_num,
        }

_norm_input_weight_dtype_match(match)

Prevent fusion when the norm input and weight dtypes differ (e.g. a Gemma fp32 weight.float()+1 gamma), covering rms_norm and fused_add_rms_norm.

Source code in vllm/compilation/passes/fusion/allreduce_rms_fusion.py
def _norm_input_weight_dtype_match(match: pm.Match) -> bool:
    """Return whether the matched norm's input and weight share a dtype.

    Intended to block fusion when they differ (e.g. a Gemma fp32
    ``weight.float()+1`` gamma). Handles both ``rms_norm`` and
    ``fused_add_rms_norm`` nodes; the verdict comes from the first such node
    whose input and weight are both FX nodes. Returns ``True`` when the match
    contains no such node.
    """
    for candidate in match.nodes:
        op = candidate.target
        # The input is always the first argument; only the weight position
        # differs between the two norm ops.
        if op == _IR_RMS_NORM_OP:
            weight_pos = 1
        elif op == _IR_FUSED_ADD_RMS_NORM_OP:
            weight_pos = 2
        else:
            continue

        norm_input, gamma = candidate.args[0], candidate.args[weight_pos]
        if not (isinstance(norm_input, fx.Node) and isinstance(gamma, fx.Node)):
            # Non-node arguments carry no traced metadata to compare.
            continue

        input_dtype = norm_input.meta["val"].dtype
        return input_dtype == gamma.meta["val"].dtype

    return True