vllm.model_executor.layers.fla.ops.fused_recurrent

Functions:

fused_recurrent_gated_delta_rule(q, k, v, g, beta=None, scale=None, initial_state=None, inplace_final_state=True, cu_seqlens=None, ssm_state_indices=None, num_accepted_tokens=None, use_qk_l2norm_in_kernel=False)

Parameters:

  • q

    (Tensor) –

    queries of shape [B, T, H, K].

  • k

    (Tensor) –

    keys of shape [B, T, H, K].

  • v

    (Tensor) –

    values of shape [B, T, HV, V]. GVA is applied if HV > H.

  • g

    (Tensor) –

    g (decays) of shape [B, T, HV].

  • beta

    (Optional[Tensor], default: None ) –

    betas of shape [B, T, HV]. If not provided, all ones are used.

  • scale

    (Optional[float], default: None ) –

    Scale factor for the attention scores. If not provided, it defaults to 1 / sqrt(K). Must be positive when given.

  • initial_state

    (Optional[Tensor], default: None ) –

    Initial state of shape [N, HV, V, K] for N input sequences. For equal-length input sequences, N equals the batch size B. Default: None.

  • inplace_final_state

    (bool, default: True ) –

    Whether to store the final state in-place to save memory. Default: True.

  • cu_seqlens

    (Optional[Tensor], default: None ) –

    Cumulative sequence lengths of shape [N+1] used for variable-length training, consistent with the FlashAttention API. When provided, the inputs must be flattened so that the batch size B is 1.

  • ssm_state_indices

    (Optional[Tensor], default: None ) –

    Indices to map the input sequences to the initial/final states.

  • num_accepted_tokens

    (Optional[Tensor], default: None ) –

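    Number of accepted tokens for each sequence during decoding.

  • use_qk_l2norm_in_kernel

    (bool, default: False ) –

    Whether to apply L2 normalization to q and k inside the kernel.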
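To make the cu_seqlens layout concrete, the sketch below builds the cumulative offsets from per-sequence lengths in plain Python. The helper names build_cu_seqlens and sequence_spans are hypothetical and not part of vLLM; in practice the resulting list would be converted to an integer tensor on the same device as q.

```python
from itertools import accumulate


def build_cu_seqlens(seq_lens):
    """Hypothetical helper: per-sequence lengths -> [N+1] cumulative offsets."""
    if any(n < 0 for n in seq_lens):
        raise ValueError("sequence lengths must be non-negative")
    # The leading 0 marks the start of the first sequence;
    # entry i + 1 is the end (exclusive) of sequence i.
    return [0, *accumulate(seq_lens)]


def sequence_spans(cu_seqlens):
    """Recover (start, end) token ranges in the flattened [1, total_T, ...] layout."""
    return list(zip(cu_seqlens[:-1], cu_seqlens[1:]))


print(build_cu_seqlens([2048, 2048, 2048, 2048]))  # [0, 2048, 4096, 6144, 8192]
print(sequence_spans(build_cu_seqlens([3, 5, 2])))  # [(0, 3), (3, 8), (8, 10)]
```

For N sequences the list has N+1 entries, which is why a batch of 4 sequences needs 5 offsets.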

Returns:

  • o ( Tensor ) –

    Outputs of shape [B, T, HV, V].

  • final_state ( Tensor ) –

    Final state of shape [N, HV, V, K].
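The output carries HV heads because each value head produces its own output, even when several value heads share one query/key head. The sketch below illustrates one plausible head mapping for HV > H. It assumes that HV is a multiple of H and that consecutive value heads are grouped onto the same query/key head; this is an assumption based on the upstream flash-linear-attention convention, not something this page states. The function name qk_head_for_value_head is hypothetical.

```python
def qk_head_for_value_head(i_hv: int, H: int, HV: int) -> int:
    """Hypothetical mapping from a value-head index to its shared q/k head."""
    if HV % H != 0:
        raise ValueError("this sketch assumes HV is a multiple of H")
    if not 0 <= i_hv < HV:
        raise IndexError("value-head index out of range")
    group_size = HV // H  # number of value heads sharing one q/k head
    return i_hv // group_size


H, HV = 4, 8
print([qk_head_for_value_head(i, H, HV) for i in range(HV)])
# [0, 0, 1, 1, 2, 2, 3, 3]
```

With H == HV the mapping is the identity, which corresponds to the ungrouped case.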

Examples:

    >>> import torch
    >>> import torch.nn.functional as F
    >>> from einops import rearrange
    >>> from fla.ops.gated_delta_rule import fused_recurrent_gated_delta_rule
    >>> # inputs with equal lengths
    >>> B, T, H, HV, K, V = 4, 2048, 4, 8, 512, 512
    >>> q = torch.randn(B, T, H, K, device='cuda')
    >>> k = F.normalize(torch.randn(B, T, H, K, device='cuda'), p=2, dim=-1)
    >>> v = torch.randn(B, T, HV, V, device='cuda')
    >>> g = F.logsigmoid(torch.rand(B, T, HV, device='cuda'))
    >>> beta = torch.rand(B, T, HV, device='cuda').sigmoid()
    >>> h0 = torch.randn(B, HV, V, K, device='cuda')
    >>> o, ht = fused_recurrent_gated_delta_rule(
    ...     q, k, v, g, beta,
    ...     initial_state=h0,
    ... )
    >>> # for variable-length inputs, the batch size B is expected to be 1 and cu_seqlens is required
    >>> q, k, v, g, beta = map(lambda x: rearrange(x, 'b t ... -> 1 (b t) ...'), (q, k, v, g, beta))
    >>> # for a batch with 4 sequences, cu_seqlens with 5 start/end positions are expected
    >>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.int32)
    >>> o_var, ht_var = fused_recurrent_gated_delta_rule(
    ...     q, k, v, g, beta,
    ...     initial_state=h0,
    ...     cu_seqlens=cu_seqlens,
    ... )

Source code in vllm/model_executor/layers/fla/ops/fused_recurrent.py
def fused_recurrent_gated_delta_rule(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    g: torch.Tensor,
    beta: torch.Tensor | None = None,
    scale: float | None = None,
    initial_state: torch.Tensor | None = None,
    inplace_final_state: bool = True,
    cu_seqlens: torch.Tensor | None = None,
    ssm_state_indices: torch.Tensor | None = None,
    num_accepted_tokens: torch.Tensor | None = None,
    use_qk_l2norm_in_kernel: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
    r"""
    Args:
        q (torch.Tensor):
            queries of shape `[B, T, H, K]`.
        k (torch.Tensor):
            keys of shape `[B, T, H, K]`.
        v (torch.Tensor):
            values of shape `[B, T, HV, V]`.
            GVA is applied if `HV > H`.
        g (torch.Tensor):
            g (decays) of shape `[B, T, HV]`.
        beta (Optional[torch.Tensor]):
            betas of shape `[B, T, HV]`.
            If not provided, all ones are used.
        scale (Optional[float]):
            Scale factor for the attention scores.
            If not provided, it defaults to `1 / sqrt(K)`. Must be positive when given.
        initial_state (Optional[torch.Tensor]):
            Initial state of shape `[N, HV, V, K]` for `N` input sequences.
            For equal-length input sequences, `N` equals the batch size `B`.
            Default: `None`.
        inplace_final_state (bool):
            Whether to store the final state in-place to save memory.
            Default: `True`.
        cu_seqlens (torch.Tensor):
            Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
            consistent with the FlashAttention API.
        ssm_state_indices (Optional[torch.Tensor]):
            Indices to map the input sequences to the initial/final states.
        num_accepted_tokens (Optional[torch.Tensor]):
            Number of accepted tokens for each sequence during decoding.
        use_qk_l2norm_in_kernel (bool):
            Whether to apply L2 normalization to `q` and `k` inside the kernel.
            Default: `False`.

    Returns:
        o (torch.Tensor):
            Outputs of shape `[B, T, HV, V]`.
        final_state (torch.Tensor):
            Final state of shape `[N, HV, V, K]`.

    Examples::
        >>> import torch
        >>> import torch.nn.functional as F
        >>> from einops import rearrange
        >>> from fla.ops.gated_delta_rule import fused_recurrent_gated_delta_rule
        # inputs with equal lengths
        >>> B, T, H, HV, K, V = 4, 2048, 4, 8, 512, 512
        >>> q = torch.randn(B, T, H, K, device='cuda')
        >>> k = F.normalize(torch.randn(B, T, H, K, device='cuda'), p=2, dim=-1)
        >>> v = torch.randn(B, T, HV, V, device='cuda')
        >>> g = F.logsigmoid(torch.rand(B, T, HV, device='cuda'))
        >>> beta = torch.rand(B, T, HV, device='cuda').sigmoid()
        >>> h0 = torch.randn(B, HV, V, K, device='cuda')
        >>> o, ht = fused_recurrent_gated_delta_rule(
        ...     q, k, v, g, beta,
        ...     initial_state=h0,
        ... )
        # for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required
        >>> q, k, v, g, beta = map(lambda x: rearrange(x, 'b t ... -> 1 (b t) ...'), (q, k, v, g, beta))
        # for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions are expected
        >>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.int32)
        >>> o_var, ht_var = fused_recurrent_gated_delta_rule(
        ...     q, k, v, g, beta,
        ...     initial_state=h0,
        ...     cu_seqlens=cu_seqlens,
        ... )
    """
    if cu_seqlens is not None and q.shape[0] != 1:
        raise ValueError(
            f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`. "
            "Please flatten variable-length inputs before processing."
        )
    if scale is None:
        scale = k.shape[-1] ** -0.5
    else:
        assert scale > 0, "scale must be positive"
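    if beta is None:
        # beta is documented as [B, T, HV]; derive the default from v so the
        # shape still matches when HV > H.
        beta = torch.ones_like(v[..., 0])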
    o, final_state = FusedRecurrentFunction.apply(
        q,
        k,
        v,
        g,
        beta,
        scale,
        initial_state,
        inplace_final_state,
        cu_seqlens,
        ssm_state_indices,
        num_accepted_tokens,
        use_qk_l2norm_in_kernel,
    )
    return o, final_state
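
The wrapper above only validates arguments and delegates to FusedRecurrentFunction, so the recurrence itself is not visible in this listing. As a rough aid to reading the shapes, here is a naive single-head sketch in plain Python. It assumes the commonly described gated delta rule update (decay the state by exp(g), then write the beta-weighted prediction error along the key). It is not vLLM's kernel: it ignores batching, cu_seqlens, ssm_state_indices, num_accepted_tokens, and in-kernel L2 normalization. The function name gated_delta_rule_reference is hypothetical.

```python
import math


def gated_delta_rule_reference(q, k, v, g, beta, scale=None, initial_state=None):
    """Naive single-head reference (hypothetical helper, not vLLM code).

    q, k: T lists of K floats; v: T lists of V floats; g, beta: T floats.
    Returns (outputs, final_state) with final_state laid out as [V][K].
    """
    K, V = len(k[0]), len(v[0])
    if scale is None:
        scale = K ** -0.5  # same default as the wrapper
    if initial_state is None:
        state = [[0.0] * K for _ in range(V)]
    else:
        state = [list(row) for row in initial_state]  # copy, never mutate the input

    outputs = []
    for q_t, k_t, v_t, g_t, beta_t in zip(q, k, v, g, beta):
        # 1. Decay the whole state; g_t <= 0 gives a factor in (0, 1].
        decay = math.exp(g_t)
        state = [[decay * s for s in row] for row in state]
        # 2. Read what the decayed state currently predicts for this key.
        predicted = [sum(s * kj for s, kj in zip(row, k_t)) for row in state]
        # 3. Delta rule: write the beta-weighted prediction error along k_t.
        for i in range(V):
            err = beta_t * (v_t[i] - predicted[i])
            for j in range(K):
                state[i][j] += err * k_t[j]
        # 4. Query the updated state and apply the scale.
        outputs.append([scale * sum(s * qj for s, qj in zip(row, q_t)) for row in state])
    return outputs, state


outputs, final_state = gated_delta_rule_reference(
    q=[[1.0, 0.0], [1.0, 0.0]],
    k=[[1.0, 0.0], [1.0, 0.0]],
    v=[[2.0], [2.0]],
    g=[0.0, 0.0],
    beta=[1.0, 1.0],
    scale=1.0,
)
print(outputs)      # [[2.0], [2.0]]
print(final_state)  # [[2.0, 0.0]]
```

In this toy run the key has unit norm, beta is 1, and there is no decay, so the first step stores v exactly and the second step sees zero prediction error and leaves the state unchanged.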