Skip to content

vllm.model_executor.layers.fla.ops.chunk

Functions:

chunk_gated_delta_rule(q, k, v, g, beta, scale=None, initial_state=None, output_final_state=False, cu_seqlens=None, chunk_indices=None, chunk_offsets=None, use_qk_l2norm_in_kernel=False, core_attn_out=None)

Parameters:

  • q

    (Tensor) –

    Queries of shape [B, T, H, K].

  • k

    (Tensor) –

    Keys of shape [B, T, H, K].

  • v

    (Tensor) –

    Values of shape [B, T, H, V].

  • g

    (Tensor) –

    (forget) Gating tensor (in log space!) of shape [B, T, H].

  • beta

    (Tensor) –

    Betas of shape [B, T, H].

  • scale

    (Optional[float], default: None) –

    Scale factor applied to the query-key attention scores. If not provided, it defaults to 1 / sqrt(K), where K is the last dimension of k. Default: None.

  • initial_state

    (Optional[Tensor], default: None ) –

    Initial state of shape [N, H, V, K] for N input sequences. For equal-length input sequences, N equals the batch size B. Default: None.

  • output_final_state

    (Optional[bool], default: False ) –

    Whether to output the final state of shape [N, H, V, K]. Default: False.

  • cu_seqlens

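    (Optional[Tensor], default: None) –

    Cumulative sequence lengths of shape [N+1] used for variable-length inputs, consistent with the FlashAttention API. The remaining parameters (chunk_indices, chunk_offsets, use_qk_l2norm_in_kernel, core_attn_out) are forwarded unchanged to ChunkGatedDeltaRuleFunction.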

Returns:

  • o (torch.Tensor): Outputs of shape [B, T, H, V].
  • final_state (Optional[torch.Tensor]): Final state of shape [N, H, V, K] if output_final_state=True, else None.

Examples::

    >>> import torch
    >>> import torch.nn.functional as F
    >>> from einops import rearrange
    >>> from vllm.model_executor.layers.fla.ops.chunk import chunk_gated_delta_rule
    # inputs with equal lengths
    >>> B, T, H, K, V = 4, 2048, 4, 512, 512
    >>> q = torch.randn(B, T, H, K, dtype=torch.bfloat16, device='cuda')
    >>> k = F.normalize(torch.randn(B, T, H, K, dtype=torch.bfloat16, device='cuda'), p=2, dim=-1)
    >>> v = torch.randn(B, T, H, V, dtype=torch.bfloat16, device='cuda')
    >>> beta = torch.rand(B, T, H, dtype=torch.bfloat16, device='cuda').sigmoid()
    >>> g = F.logsigmoid(torch.rand(B, T, H, dtype=torch.bfloat16, device='cuda'))
    >>> h0 = torch.randn(B, H, V, K, dtype=torch.bfloat16, device='cuda')
    >>> o, ht = chunk_gated_delta_rule(
    ...     q, k, v, g, beta,
    ...     initial_state=h0,
    ...     output_final_state=True
    ... )
    # for variable-length inputs, the batch size B is expected to be 1 and cu_seqlens is required
    >>> q, k, v, beta, g = map(lambda x: rearrange(x, 'b t ... -> 1 (b t) ...'), (q, k, v, beta, g))
    # for a batch with 4 sequences, cu_seqlens with 5 start/end positions are expected
    >>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.int32)
    >>> o_var, ht_var = chunk_gated_delta_rule(
    ...     q, k, v, g, beta,
    ...     initial_state=h0,
    ...     output_final_state=True,
    ...     cu_seqlens=cu_seqlens
    ... )

Source code in vllm/model_executor/layers/fla/ops/chunk.py
@torch.compiler.disable
def chunk_gated_delta_rule(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    g: torch.Tensor,
    beta: torch.Tensor,
    scale: float = None,
    initial_state: torch.Tensor = None,
    output_final_state: bool = False,
    cu_seqlens: torch.Tensor | None = None,
    chunk_indices: torch.Tensor | None = None,
    chunk_offsets: torch.Tensor | None = None,
    use_qk_l2norm_in_kernel: bool = False,
    core_attn_out: torch.Tensor | None = None,
):
    r"""
    Args:
        q (torch.Tensor):
            Queries of shape `[B, T, H, K]`.
        k (torch.Tensor):
            Keys of shape `[B, T, H, K]`.
        v (torch.Tensor):
            Values of shape `[B, T, H, V]`.
        g (torch.Tensor):
            (forget) Gating tensor (in log space!) of shape `[B, T, H]`.
        beta (torch.Tensor):
            Betas of shape `[B, T, H]`.
        scale (Optional[int]):
            Scale factor for the RetNet attention scores.
            If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
        initial_state (Optional[torch.Tensor]):
            Initial state of shape `[N, H, V, K]` for `N` input sequences.
            For equal-length input sequences, `N` equals the batch size `B`.
            Default: `None`.
        output_final_state (Optional[bool]):
            Whether to output the final state of shape `[N, H, V, K]`. Default: `False`.
        cu_seqlens (torch.Tensor):
            Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
            consistent with the FlashAttention API.
    Returns:
        o (torch.Tensor):
            Outputs of shape `[B, T, H, V]`.
        final_state (torch.Tensor):
            Final state of shape `[N, H, V, K]` if `output_final_state=True` else `None`.

    Examples::
        >>> import torch
        >>> import torch.nn.functional as F
        >>> from einops import rearrange
        >>> from fla.ops.gated_delta_rule import chunk_gated_delta_rule
        # inputs with equal lengths
        >>> B, T, H, K, V = 4, 2048, 4, 512, 512
        >>> q = torch.randn(B, T, H, K, dtype=torch.bfloat16, device='cuda')
        >>> k = F.normalize(torch.randn(B, T, H, K, dtype=torch.bfloat16, device='cuda'), p=2, dim=-1)
        >>> v = torch.randn(B, T, H, V, dtype=torch.bfloat16, device='cuda')
        >>> beta = torch.rand(B, T, H, dtype=torch.bfloat16, device='cuda').sigmoid()
        >>> g = F.logsigmoid(torch.rand(B, T, H, dtype=torch.bfloat16, device='cuda'))
        >>> h0 = torch.randn(B, H, V, K, dtype=torch.bfloat16, device='cuda')
        >>> o, ht = chunk_gated_delta_rule(
            q, k, v, g, beta,
            initial_state=h0,
            output_final_state=True
        )
        # for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required
        >>> q, k, v, beta, g = map(lambda x: rearrange(x, 'b t ... -> 1 (b t) ...'), (q, k, v, beta, g))
        # for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions are expected
        >>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.int32)
        >>> o_var, ht_var = chunk_gated_delta_rule(
            q, k, v, g, beta,
            initial_state=h0,
            output_final_state=True,
            cu_seqlens=cu_seqlens
        )
    """
    assert q.dtype == k.dtype == v.dtype
    assert q.dtype != torch.float32, (
        "ChunkGatedDeltaRuleFunction does not support float32. Please use bfloat16."
    )
    assert len(beta.shape) == 3, "beta must be of shape [B, T, H]."
    if cu_seqlens is not None:
        if q.shape[0] != 1:
            raise ValueError(
                f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`."
                f"Please flatten variable-length inputs before processing."
            )
        if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1:
            raise ValueError(
                f"The number of initial states is expected to be equal to the number of input sequences, "
                f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}."
            )
    if scale is None:
        scale = k.shape[-1] ** -0.5
    o, final_state = ChunkGatedDeltaRuleFunction.apply(
        q,
        k,
        v,
        g,
        beta,
        scale,
        initial_state,
        output_final_state,
        cu_seqlens,
        chunk_indices,
        chunk_offsets,
        use_qk_l2norm_in_kernel,
        core_attn_out,
    )
    return o, final_state