vllm.v1.attention.ops.dcp_alltoall
DCP All-to-All communication backend for attention.
Provides All-to-All (A2A) communication as an alternative to AllGather + ReduceScatter (AG+RS) for Decode Context Parallel (DCP). Instead of gathering the full Q tensor and scattering partial outputs, A2A exchanges each rank's partial attention outputs and their log-sum-exp (LSE) values across ranks, then combines them with an exact LSE-weighted reduction.
This reduces the number of NCCL calls per attention layer by exchanging the partial output and LSE in a single packed All-to-All payload.
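The exactness of the LSE-weighted reduction can be checked with a small, library-free sketch. It is not the module's code: `attend` and `lse_merge` are hypothetical helpers, and the example uses a single query with scalar values so that the arithmetic is easy to follow. Each KV shard produces a partial output and its LSE; weighting each partial by `exp(lse_i - global_lse)` recovers the result of attending over all keys at once.

```python
import math

def attend(scores, values):
    """Softmax attention for one query over scalar values.

    Returns (output, lse), where lse = log(sum(exp(scores))).
    """
    m = max(scores)  # subtract the max for numerical stability
    z = sum(math.exp(s - m) for s in scores)
    out = sum(math.exp(s - m) * v for s, v in zip(scores, values)) / z
    return out, m + math.log(z)

def lse_merge(partials):
    """Merge (output, lse) pairs computed on disjoint KV shards."""
    m = max(lse for _, lse in partials)
    weights = [math.exp(lse - m) for _, lse in partials]
    total = sum(weights)
    out = sum(w * o for w, (o, _) in zip(weights, partials)) / total
    return out, m + math.log(total)

scores = [0.3, -1.2, 2.0, 0.7, -0.5, 1.1]
values = [1.0, 2.0, -1.0, 0.5, 3.0, -2.0]

# Attention over all six keys at once.
full_out, full_lse = attend(scores, values)

# Three "ranks", each holding two keys, followed by the LSE-weighted merge.
shards = [attend(scores[i:i + 2], values[i:i + 2]) for i in range(0, 6, 2)]
merged_out, merged_lse = lse_merge(shards)
```

Up to floating-point rounding, `merged_out` and `merged_lse` agree with `full_out` and `full_lse`, which is why the reduction is described as exact rather than approximate.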
Usage
vllm serve model --tp 16 --dcp 16 --dcp-comm-backend a2a
Reference: https://arxiv.org/abs/2507.07120
Functions:
- dcp_a2a_lse_reduce – Combine partial attention outputs across DCP ranks using All-to-All.
_lse_weighted_combine(outputs, lses, return_lse=False, is_lse_base_on_e=True)
CPU reference implementation for LSE-weighted combination.
This is a pure PyTorch implementation used for testing and validation.
Parameters:
- outputs (Tensor) – Partial attention outputs [N, B, H, D], where N = number of KV shards (ranks), B = batch size (num_tokens), H = number of heads per rank, D = head dimension.
- lses (Tensor) – Log-sum-exp values [N, B, H].
- return_lse (bool, default: False) – If True, also return the global LSE.
- is_lse_base_on_e (bool, default: True) – If True, LSE is base e; if False, base 2.
Returns:
- Combined output [B, H, D]; if return_lse=True, also the global LSE [B, H].
Source code in vllm/v1/attention/ops/dcp_alltoall.py
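The following is a minimal sketch of what an LSE-weighted combination with this signature might look like, written from the parameter descriptions above rather than copied from the source file. The name `lse_weighted_combine_sketch` is hypothetical, and handling of shards whose LSE is -inf for every rank is deliberately omitted.

```python
import math
import torch

def lse_weighted_combine_sketch(outputs, lses, return_lse=False,
                                is_lse_base_on_e=True):
    """Combine partial outputs [N, B, H, D] using their LSEs [N, B, H].

    Hypothetical illustration; not the vLLM implementation.
    """
    lses = lses.float()
    if not is_lse_base_on_e:
        # log2(x) * ln(2) == ln(x): convert base-2 LSEs to natural log.
        lses = lses * math.log(2.0)

    global_lse = torch.logsumexp(lses, dim=0)       # [B, H]
    weights = torch.exp(lses - global_lse)          # [N, B, H], sums to 1 over N
    combined = (weights.unsqueeze(-1) * outputs.float()).sum(dim=0)  # [B, H, D]
    combined = combined.to(outputs.dtype)

    if return_lse:
        if not is_lse_base_on_e:
            # Report the global LSE in the same base as the inputs.
            global_lse = global_lse / math.log(2.0)
        return combined, global_lse
    return combined

torch.manual_seed(0)
N, B, H, D = 4, 3, 2, 8
outputs = torch.randn(N, B, H, D)
lses = torch.randn(N, B, H)

out_e, lse_e = lse_weighted_combine_sketch(outputs, lses, return_lse=True)

# The same data expressed with base-2 LSEs should give the same output.
out_2, lse_2 = lse_weighted_combine_sketch(
    outputs, lses / math.log(2.0), return_lse=True, is_lse_base_on_e=False
)
```

The accumulation is done in fp32 and cast back to the input dtype at the end; whether the real function does the same is an assumption of this sketch.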
dcp_a2a_lse_reduce(cp_attn_out, cp_attn_lse, cp_group, ctx=None, return_lse=False, is_lse_base_on_e=True)
Combine partial attention outputs across DCP ranks using All-to-All.
The output and fp32 LSE are packed into a single output-dtype buffer, sent with one All-to-All, then unpacked and combined with exact LSE weighting.
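One way to pack an fp32 LSE into a 16-bit output-dtype buffer is to reinterpret each fp32 value as two 16-bit elements and append them to the head dimension. The sketch below illustrates that idea only; the layout (two extra slots at the end of each head vector) and the helper names `pack_out_and_lse` and `unpack_out_and_lse` are assumptions, not the layout used by the module.

```python
import torch

def pack_out_and_lse(out, lse):
    """Pack out [B, H, D] (16-bit dtype) and lse [B, H] (fp32) into [B, H, D + 2].

    Hypothetical layout: the fp32 bits of each LSE occupy two trailing slots.
    """
    assert out.element_size() == 2 and lse.dtype == torch.float32
    # Reinterpret the fp32 bits as two 16-bit elements; no value conversion.
    lse_bits = lse.contiguous().unsqueeze(-1).view(out.dtype)   # [B, H, 2]
    return torch.cat([out, lse_bits], dim=-1)                   # [B, H, D + 2]

def unpack_out_and_lse(packed, head_dim):
    """Inverse of pack_out_and_lse."""
    out = packed[..., :head_dim]
    # Make the two trailing slots contiguous, then view them as one fp32 value.
    lse = packed[..., head_dim:].contiguous().view(torch.float32).squeeze(-1)
    return out, lse

torch.manual_seed(0)
B, H, D = 3, 4, 8
out = torch.randn(B, H, D).to(torch.bfloat16)
lse = torch.randn(B, H, dtype=torch.float32)

packed = pack_out_and_lse(out, lse)
out2, lse2 = unpack_out_and_lse(packed, D)
```

Because the LSE is reinterpreted bitwise rather than cast, it keeps full fp32 precision through the exchange, and a single buffer means a single All-to-All call can carry both tensors.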
Parameters:
- cp_attn_out (Tensor) – Partial attention output [B, H, D], where B = num_tokens, H = total_heads, D = head_dim.
- cp_attn_lse (Tensor) – Log-sum-exp values [B, H] (fp32).
- cp_group (GroupCoordinator) – GroupCoordinator for DCP communication.
- ctx (CPTritonContext | None, default: None) – Unused; kept for signature compatibility.
- return_lse (bool, default: False) – If True, also return the combined global LSE.
- is_lse_base_on_e (bool, default: True) – If True, LSE is base e; if False, base 2.
Returns:
- Tensor | tuple[Tensor, Tensor] – Combined output [B, H/N, D] (head-scattered), where N is the number of DCP ranks. If return_lse=True, also returns global_lse [B, H/N].
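The data movement implied by these shapes can be simulated in a single process, without NCCL or a GroupCoordinator. In the sketch below, which is an illustration under stated assumptions and not the module's code, every simulated rank holds a KV shard and computes partial attention for all H heads; the All-to-All is modelled by handing head chunk `dst` from every rank to rank `dst`. The helpers `attention` and `simulate_dcp_a2a` are hypothetical.

```python
import torch

def attention(q, k, v):
    """Plain attention. q: [B, H, D]; k, v: [S, H, D]. Returns (out, lse)."""
    scores = torch.einsum("bhd,shd->bhs", q, k) / q.shape[-1] ** 0.5
    lse = torch.logsumexp(scores, dim=-1)                        # [B, H]
    out = torch.einsum("bhs,shd->bhd", torch.softmax(scores, dim=-1), v)
    return out, lse

def simulate_dcp_a2a(partial_outs, partial_lses):
    """Simulate the exchange and combine for N ranks in one process.

    partial_outs: N tensors [B, H, D]; partial_lses: N tensors [B, H].
    Returns N tensors [B, H/N, D], one per destination rank.
    """
    n = len(partial_outs)
    results = []
    for dst in range(n):
        # Rank `dst` receives head chunk `dst` from every source rank.
        recv_out = torch.stack([o.chunk(n, dim=1)[dst] for o in partial_outs])
        recv_lse = torch.stack([l.chunk(n, dim=1)[dst] for l in partial_lses])
        global_lse = torch.logsumexp(recv_lse, dim=0)            # [B, H/N]
        w = torch.exp(recv_lse - global_lse).unsqueeze(-1)       # [N, B, H/N, 1]
        results.append((w * recv_out).sum(dim=0))                # [B, H/N, D]
    return results

torch.manual_seed(0)
N, B, H, D, S = 4, 3, 8, 16, 20
q = torch.randn(B, H, D, dtype=torch.float64)
k = torch.randn(S, H, D, dtype=torch.float64)
v = torch.randn(S, H, D, dtype=torch.float64)

# Reference: attention over the full KV sequence.
full_out, _ = attention(q, k, v)

# Each simulated rank attends over its own KV shard for all heads.
partials = [attention(q, ks, vs) for ks, vs in zip(k.chunk(N), v.chunk(N))]
per_rank = simulate_dcp_a2a([p[0] for p in partials], [p[1] for p in partials])
```

Concatenating `per_rank` along the head dimension reproduces `full_out` up to rounding, which matches the head-scattered [B, H/N, D] layout described in the Returns entry.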