vllm.models.deepseek_v4.common.ops.fused_inv_rope_fp8_quant

Fused inverse RoPE + block-scaled FP8 quantization kernel for DeepseekV4 attention.
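
To make the two fused steps concrete, here is a minimal unfused sketch in pure Python for a single head. It is not the kernel's implementation. It assumes that the RoPE dimensions are the trailing `rope_dim` elements of each head, that rotation pairs are interleaved as `(x[2i], x[2i+1])`, and that each block's scale is `amax / 448`. None of these conventions is stated on this page, and the helper names are hypothetical. Rounding to actual FP8 values is omitted.

```python
import math

FP8_E4M3_MAX = 448.0  # largest finite float8_e4m3fn value


def inv_rope(vec, cos, sin):
    """Undo a rotary embedding on `vec` (length 2 * len(cos)).

    Assumption: interleaved pairs (vec[2i], vec[2i+1]).
    """
    out = list(vec)
    for i, (c, s) in enumerate(zip(cos, sin)):
        a, b = vec[2 * i], vec[2 * i + 1]
        out[2 * i] = a * c + b * s        # rotate the pair by -theta
        out[2 * i + 1] = -a * s + b * c
    return out


def quantize_blocks(row, block):
    """Scale each block by amax / FP8_E4M3_MAX and clamp to the FP8 range."""
    scaled, scales = [], []
    for start in range(0, len(row), block):
        chunk = row[start:start + block]
        amax = max(abs(v) for v in chunk)
        scale = max(amax, 1e-10) / FP8_E4M3_MAX  # guard against all-zero blocks
        scales.append(scale)
        scaled.extend(
            max(-FP8_E4M3_MAX, min(FP8_E4M3_MAX, v / scale)) for v in chunk
        )
    return scaled, scales


def reference_head(head, cos, sin, nope_dim=448, block=128):
    """Inverse-rotate the trailing RoPE dims, then block-quantize the head."""
    nope, rope = head[:nope_dim], head[nope_dim:]
    return quantize_blocks(nope + inv_rope(rope, cos, sin), block)
```

Fusing the two steps lets a kernel read the bf16 input once and write FP8 values plus scales directly, instead of materializing the de-rotated tensor in between.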

The output scales are written directly in the layout fp8_einsum expects (MN-major, TMA-aligned; FP32 on SM90, INT32-packed UE8M0 on SM100), so fp8_einsum can skip transform_sf_into_required_layout.
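
As an illustration of what "INT32-packed UE8M0" can mean, the sketch below encodes power-of-two scales as 8-bit biased exponents and packs four of them per 32-bit word. It assumes the usual E8M0 convention (value = 2^(e - 127)), rounds each scale up to the next power of two, and places the first block in the lowest byte. The rounding rule and byte order used by the real kernel are not documented on this page, and both helper names are hypothetical.

```python
import math


def ue8m0_exponent(scale: float) -> int:
    """Biased exponent of the smallest power of two >= scale (scale > 0)."""
    mantissa, exp = math.frexp(scale)  # scale = mantissa * 2**exp, 0.5 <= mantissa < 1
    if mantissa == 0.5:                # scale is already an exact power of two
        exp -= 1
    return max(0, min(255, exp + 127))


def pack_ue8m0(exponents):
    """Pack four 8-bit exponents per 32-bit word, zero-padding the tail.

    Returns unsigned Python ints; an INT32 tensor would hold the same bits.
    """
    padded = list(exponents) + [0] * (-len(exponents) % 4)
    words = []
    for i in range(0, len(padded), 4):
        b0, b1, b2, b3 = padded[i:i + 4]
        words.append(b0 | (b1 << 8) | (b2 << 16) | (b3 << 24))
    return words
```

The word count produced by `pack_ue8m0` matches the `(num_scale_blocks + 3) // 4` rounding used for the packed scale dimension in the source listing.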

Functions:

fused_inv_rope_fp8_quant(o, positions, cos_sin_cache, n_groups, heads_per_group, nope_dim=448, rope_dim=64, quant_group_size=128, tma_aligned_scales=False)

Fused inverse RoPE + block-scaled FP8 quantization.

Parameters:

  • o

    (Tensor) –

    Attention output [num_tokens, num_heads, head_dim] bf16.

  • positions

    (Tensor) –

    Token positions [num_tokens] int64.

  • cos_sin_cache

    (Tensor) –

    Precomputed [max_pos, rope_dim] float32 with cos||sin.

  • n_groups

    (int) –

    Number of output groups.

  • heads_per_group

    (int) –

    Heads per group.

  • nope_dim

    (int, default: 448 ) –

    Non-RoPE dimensions per head (default 448).

  • rope_dim

    (int, default: 64 ) –

    RoPE dimensions per head (default 64).

  • quant_group_size

    (int, default: 128 ) –

    FP8 quantization block size (default 128).

  • tma_aligned_scales

    (bool, default: False ) –

    If True, output INT32-packed UE8M0 scales (SM100); if False, output FP32 scales (SM90).

Returns:

  • o_fp8 ( Tensor ) –

    [T, G, D] float8_e4m3fn with strides (D, T*D, 1), where T = num_tokens, G = n_groups, and D = heads_per_group * head_dim.

  • o_scale ( Tensor ) –

    Pre-transformed scale tensor for fp8_einsum.
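
The shape arithmetic behind these return values can be checked without a GPU. The sketch below mirrors the assertions and derived sizes from the source listing. `derive_layout` is a hypothetical helper, not part of vLLM, and it does not reproduce the TMA alignment of the token dimension (`get_tma_aligned_size`).

```python
def derive_layout(num_tokens, n_groups, heads_per_group,
                  nope_dim=448, rope_dim=64, quant_group_size=128,
                  tma_aligned_scales=False):
    """Return the output shape, strides, and scale counts for given sizes."""
    head_dim = nope_dim + rope_dim
    # Same shape constraints the wrapper asserts before launching the kernel.
    assert head_dim % quant_group_size == 0
    assert nope_dim % quant_group_size == quant_group_size - rope_dim
    assert rope_dim % 2 == 0

    d = heads_per_group * head_dim           # D: features per (token, group)
    num_scale_blocks = d // quant_group_size
    if tma_aligned_scales:
        scale_inner = (num_scale_blocks + 3) // 4  # four UE8M0 scales per INT32
    else:
        scale_inner = num_scale_blocks             # one FP32 scale per block

    return {
        "o_fp8_shape": (num_tokens, n_groups, d),
        # A [G, T, D] buffer viewed as [T, G, D] via transpose(0, 1).
        "o_fp8_strides": (d, num_tokens * d, 1),
        "num_scale_blocks": num_scale_blocks,
        "scale_inner": scale_inner,
    }
```

For example, `derive_layout(10, 8, 8)` describes 10 tokens with 64 heads split into 8 groups: each group row has D = 4096 features and 32 scale blocks, which pack into 8 INT32 words when `tma_aligned_scales=True`.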

Source code in vllm/models/deepseek_v4/common/ops/fused_inv_rope_fp8_quant.py
def fused_inv_rope_fp8_quant(
    o: torch.Tensor,
    positions: torch.Tensor,
    cos_sin_cache: torch.Tensor,
    n_groups: int,
    heads_per_group: int,
    nope_dim: int = 448,
    rope_dim: int = 64,
    quant_group_size: int = 128,
    tma_aligned_scales: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Fused inverse RoPE + block-scaled FP8 quantization.

    Args:
        o: Attention output [num_tokens, num_heads, head_dim] bf16.
        positions: Token positions [num_tokens] int64.
        cos_sin_cache: Precomputed [max_pos, rope_dim] with cos||sin.
        n_groups: Number of output groups.
        heads_per_group: Heads per group.
        nope_dim: Non-RoPE dimensions per head (default 448).
        rope_dim: RoPE dimensions per head (default 64).
        quant_group_size: FP8 quantization block size (default 128).
        tma_aligned_scales: Output INT32 packed UE8M0 for SM100 (True)
                            or FP32 for SM90 (False).

    Returns:
        o_fp8: [T, G, D] float8_e4m3fn, strides (D, T*D, 1).
        o_scale: Pre-transformed scale tensor for fp8_einsum.
    """
    from vllm.utils.deep_gemm import get_tma_aligned_size

    num_tokens, num_heads, head_dim = o.shape
    assert num_heads == n_groups * heads_per_group
    assert head_dim == nope_dim + rope_dim
    assert head_dim % quant_group_size == 0
    # The RoPE dims must exactly fill the tail of each head's last quant block.
    assert nope_dim % quant_group_size == (quant_group_size - rope_dim)
    assert rope_dim % 2 == 0
    assert cos_sin_cache.shape[-1] == rope_dim
    assert cos_sin_cache.dtype == torch.float32

    d = heads_per_group * head_dim
    num_scale_blocks = d // quant_group_size
    chunks_per_head = head_dim // quant_group_size

    fp8_dtype = torch.float8_e4m3fn
    fp8_max = torch.finfo(fp8_dtype).max

    tma_aligned_T = get_tma_aligned_size(num_tokens, 4)
    if tma_aligned_scales:
        packed_sf_k = (num_scale_blocks + 3) // 4
        scale_inner = packed_sf_k
    else:
        scale_inner = num_scale_blocks

    # Run the kernel through a custom op so Inductor sees an opaque boundary.
    # This works around a PyTorch bug; see https://github.com/vllm-project/vllm/issues/41106
    fp8_buf, scale_buf = torch.ops.vllm.fused_inv_rope_fp8_quant_kernel(
        o,
        positions,
        cos_sin_cache,
        heads_per_group,
        quant_group_size,
        chunks_per_head,
        nope_dim % quant_group_size,
        rope_dim // 2,
        tma_aligned_scales,
        fp8_max,
        tma_aligned_T,
        num_tokens,
        n_groups,
        d,
        scale_inner,
    )
    return fp8_buf.transpose(0, 1), scale_buf.transpose(0, 1)