Skip to content

vllm.forward_context

Classes:

Functions:

BatchDescriptor dataclass

Batch descriptor for cudagraph dispatching. Keep the number of fields as small as possible while still properly and uniquely describing the padded batch for cudagraph.

Attributes:

  • has_lora (bool) –

    Whether this batch has active LoRA adapters.

  • num_active_loras (int) –

    Number of distinct active LoRA adapters in this batch.

  • num_reqs (int | None) –

    Number of requests in the batch. Can be None for PIECEWISE cudagraphs where the cudagraphs can handle any number of requests.

  • uniform (bool) –

    True if all the requests in the batch have the same number of tokens.

Source code in vllm/forward_context.py
@dataclass(frozen=True)
class BatchDescriptor:
    """
    Batch descriptor for cudagraph dispatching. We should keep the num of
    items as minimal as possible to properly and uniquely describe the padded
    batch for cudagraph.
    """

    num_tokens: int
    num_reqs: int | None = None
    """
    Number of requests in the batch. Can be None for PIECEWISE cudagraphs where
    the cudagraphs can handle any number of requests.
    """
    uniform: bool = False
    """
    True if all the requests in the batch have the same number of tokens.
    """
    has_lora: bool = False
    """
    Whether this batch has active LoRA adapters.
    """
    num_active_loras: int = 0
    """
    Number of distinct active LoRA adapters in this batch.
    When cudagraph_specialize_lora_count is enabled, separate CUDA graphs
    are captured for each num_active_loras value. This allows kernels
    (like fused_moe_lora) whose grid size depends on num_active_loras
    to be properly captured.
    """

has_lora = False class-attribute instance-attribute

Whether this batch has active LoRA adapters.

num_active_loras = 0 class-attribute instance-attribute

Number of distinct active LoRA adapters in this batch. When cudagraph_specialize_lora_count is enabled, separate CUDA graphs are captured for each num_active_loras value. This allows kernels (like fused_moe_lora) whose grid size depends on num_active_loras to be properly captured.

num_reqs = None class-attribute instance-attribute

Number of requests in the batch. Can be None for PIECEWISE cudagraphs where the cudagraphs can handle any number of requests.

uniform = False class-attribute instance-attribute

True if all the requests in the batch have the same number of tokens.

DPMetadata dataclass

Methods:

  • sp_local_sizes

    Context manager for setting self.local_sizes. Same as self.chunked_sizes but without any chunking.

Source code in vllm/forward_context.py
@dataclass
class DPMetadata:
    """Data-parallel (DP) metadata for a single forward pass.

    Stores the number of tokens scheduled on every DP rank and, while the
    ``sp_local_sizes`` context manager is active, the corresponding per-rank
    local sizes used for sequence parallelism.
    """

    # Token count of each DP rank, indexed by DP rank. Judging by the name it
    # lives on the CPU (TODO: confirm against callers).
    num_tokens_across_dp_cpu: torch.Tensor

    # NOTE: local_sizes should only be set by the sp_local_sizes context
    # manager; it is None outside of that context.
    local_sizes: list[int] | None = None

    @staticmethod
    def make(
        parallel_config: ParallelConfig,
        num_tokens: int,
        num_tokens_across_dp_cpu: torch.Tensor,
    ) -> "DPMetadata":
        """Validate the inputs and build a DPMetadata.

        Args:
            parallel_config: Parallel configuration. Must use more than one DP
                rank and must not be explicitly marked as a non-MoE model.
            num_tokens: Number of tokens scheduled on this DP rank.
            num_tokens_across_dp_cpu: Per-DP-rank token counts, already
                computed by the caller. Its entry for this rank must equal
                ``num_tokens``.

        Returns:
            A DPMetadata wrapping ``num_tokens_across_dp_cpu``.

        Raises:
            AssertionError: If any of the preconditions above is violated.
        """
        assert num_tokens_across_dp_cpu is not None
        assert parallel_config.data_parallel_size > 1
        # Only an explicit False is rejected; other values (e.g. None) pass.
        assert parallel_config.is_moe_model is not False
        dp_rank = parallel_config.data_parallel_rank

        # This method never computes the per-rank counts itself: they must be
        # provided, and this rank's entry has to agree with num_tokens.
        assert num_tokens_across_dp_cpu[dp_rank] == num_tokens, (
            f"{num_tokens_across_dp_cpu[dp_rank]} {num_tokens}"
        )
        return DPMetadata(num_tokens_across_dp_cpu)

    @contextmanager
    def sp_local_sizes(self, sequence_parallel_size: int):
        """Temporarily set ``self.local_sizes`` for sequence parallelism.

        The sizes are computed from ``num_tokens_across_dp_cpu`` and
        ``sequence_parallel_size`` without any chunking. They are yielded to
        the caller and reset to None on exit, even if the body raises.

        NOTE(review): exit resets to None rather than to any previous value,
        so nesting this context manager is not supported.
        """
        self.local_sizes = _compute_sp_num_tokens(
            self.num_tokens_across_dp_cpu, sequence_parallel_size
        )
        try:
            yield self.local_sizes
        finally:
            self.local_sizes = None

    def get_chunk_sizes_across_dp_rank(self) -> list[int]:
        """Return the local sizes set by ``sp_local_sizes``.

        Must be called while ``sp_local_sizes`` is active. Otherwise
        ``local_sizes`` is None and an AssertionError is raised.
        """
        assert self.local_sizes is not None
        return self.local_sizes

    def cu_tokens_across_sp(self, sp_size: int) -> torch.Tensor:
        """Return cumulative token counts across sequence-parallel (SP) ranks.

        Here the input to the MoEs is distributed w.r.t. both the DP and TP
        rank. Each DP rank's token count is split across ``sp_size`` ranks
        (rounded up) and repeated once per SP rank, then cumulatively summed.
        When ``sp_size == 1`` this is just the cumulative number of tokens
        across DP ranks.
        """
        # Ceil-divide each DP rank's token count by sp_size.
        num_tokens_across_sp_cpu = (
            self.num_tokens_across_dp_cpu - 1 + sp_size
        ) // sp_size
        # One entry per (DP rank, SP rank) pair, grouped by DP rank
        # (assuming a 1-D input tensor).
        num_tokens_across_sp_cpu = num_tokens_across_sp_cpu.repeat_interleave(sp_size)
        return torch.cumsum(num_tokens_across_sp_cpu, dim=0)

sp_local_sizes(sequence_parallel_size)

Context manager for setting self.local_sizes. Same as self.chunked_sizes but without any chunking.

Source code in vllm/forward_context.py
@contextmanager
def sp_local_sizes(self, sequence_parallel_size: int):
    """
    Temporarily populate ``self.local_sizes`` with the per-rank token counts
    for sequence parallelism (no chunking is applied).

    The computed sizes are yielded to the caller, and ``self.local_sizes``
    is reset to None on exit, even if the body raises.
    """
    sizes = _compute_sp_num_tokens(
        self.num_tokens_across_dp_cpu, sequence_parallel_size
    )
    self.local_sizes = sizes
    try:
        yield sizes
    finally:
        self.local_sizes = None

ForwardContext dataclass

Attributes:

Source code in vllm/forward_context.py
@dataclass
class ForwardContext:
    """State shared by all layers during a single model forward pass.

    Instances are installed with ``set_forward_context`` (or
    ``override_forward_context``) and read back with ``get_forward_context``.
    """

    # copy from vllm_config.compilation_config.static_forward_context
    no_compile_layers: dict[str, Any]
    attn_metadata: dict[str, AttentionMetadata] | list[dict[str, AttentionMetadata]]
    """
    Type Dict[str, AttentionMetadata] for v1: maps the layer_name of each
    attention layer to its attention metadata.
    Type List[Dict[str, AttentionMetadata]] for DBO: a list of size two, one
    entry for each microbatch.
    Set dynamically for each forward pass.
    """
    slot_mapping: dict[str, torch.Tensor] | list[dict[str, torch.Tensor]]
    """
    Slot-mapping tensors for the forward pass, set dynamically for each pass.
    A single dict of tensors, or a list of such dicts (presumably one per
    microbatch for DBO, mirroring ``attn_metadata``). Keys are presumably
    attention layer names, as for ``attn_metadata`` (TODO: confirm).
    """
    # set dynamically for each forward pass
    dp_metadata: DPMetadata | None = None
    # determine the cudagraph style at runtime to be FULL, PIECEWISE, or NONE.
    # by default NONE, no cudagraph is used.
    cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE
    batch_descriptor: BatchDescriptor | None = None

    ubatch_slices: UBatchSlices | None = None

    # If True, bypass the compiled model call, e.g. by using .forward() directly
    skip_compiled: bool = False

    # For torch.compile cold start times, we need to avoid hard-coding
    # any strings into the graph. Right now, the vllm.moe_forward
    # and vllm.moe_forward_shared custom operators hard-code strings into
    # the graph.
    #
    # The workaround is to store a list of the strings that each of those
    # custom ops needs in the ForwardContext (all_moe_layers)
    # as well as a counter (moe_layer_index).
    # The ForwardContext object is alive for the duration of the forward pass.
    # When the custom op needs a layer string, get the next string
    # from all_moe_layers and increment the counter.
    #
    # This assumes that the custom operators will always be executed in
    # order and that torch.compile will not try to reorder these
    # operations with respect to each other.
    #
    # TODO(https://github.com/vllm-project/vllm/issues/31985):
    # There are longer-term solutions, like unwrapping the moe custom operator,
    # that aren't ready yet.
    # We could also treat the string as a "symbolic input" to the graph but
    # the PyTorch-side bits for that aren't ready yet either.
    #
    # If this value is None (like in some tests), then we end up baking the string
    # into the graph. Otherwise, the moe custom ops will pop a string from this list.
    all_moe_layers: list[str] | None = None
    moe_layer_index: int = 0

    additional_kwargs: dict[str, Any] = field(default_factory=dict)

    def __post_init__(self):
        # Reject modes that are not valid at runtime, failing at construction
        # time rather than mid-forward.
        assert self.cudagraph_runtime_mode.is_valid_runtime_mode(), (
            f"Invalid cudagraph runtime mode: {self.cudagraph_runtime_mode}"
        )

slot_mapping instance-attribute

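Type Dict[str, AttentionMetadata] for v1: maps the layer_name of each attention layer to its attention metadata. Type List[Dict[str, AttentionMetadata]] for DBO: a list of size two, one entry for each microbatch. Set dynamically for each forward pass. (Note: this description matches the type of attn_metadata; slot_mapping itself is typed dict[str, torch.Tensor] | list[dict[str, torch.Tensor]].)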

get_forward_context()

Get the current forward context.

Source code in vllm/forward_context.py
def get_forward_context() -> ForwardContext:
    """Return the forward context that is currently installed.

    Raises an AssertionError when no context is installed, e.g. when called
    outside of ``set_forward_context``.
    """
    current = _forward_context
    assert current is not None, (
        "Forward context is not set. "
        "Please use `set_forward_context` to set the forward context."
    )
    return current

override_forward_context(forward_context)

A context manager that overrides the current forward context. This is used to override the forward context for a specific forward pass.

Source code in vllm/forward_context.py
@contextmanager
def override_forward_context(forward_context: ForwardContext | None):
    """Temporarily install ``forward_context`` as the global forward context.

    Whatever context was active before (possibly None) is restored on exit,
    even if the body raises, so overrides can be nested.
    """
    global _forward_context
    saved, _forward_context = _forward_context, forward_context
    try:
        yield
    finally:
        _forward_context = saved

set_forward_context(attn_metadata, vllm_config, num_tokens=None, num_tokens_across_dp=None, cudagraph_runtime_mode=CUDAGraphMode.NONE, batch_descriptor=None, ubatch_slices=None, slot_mapping=None, skip_compiled=False)

A context manager that stores the current forward context (attention metadata, etc.). This is where common logic for every model forward pass can be injected.

Source code in vllm/forward_context.py
@contextmanager
def set_forward_context(
    attn_metadata: Any,
    vllm_config: VllmConfig,
    num_tokens: int | None = None,
    num_tokens_across_dp: torch.Tensor | None = None,
    cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
    batch_descriptor: BatchDescriptor | None = None,
    ubatch_slices: UBatchSlices | None = None,
    slot_mapping: dict[str, torch.Tensor] | list[dict[str, torch.Tensor]] | None = None,
    skip_compiled: bool = False,
):
    """Install a ForwardContext for the duration of a model forward pass.

    This is the place where logic common to every model forward pass is
    injected: DP metadata, a default cudagraph BatchDescriptor, platform-
    specific extra state, and optional batch-size timing statistics. The
    context is installed via ``override_forward_context``, so the previously
    active context is restored when the ``with`` block exits.

    Args:
        attn_metadata: Attention metadata for this pass (may be None). When
            None, batch-size timing is skipped, and DP metadata is only built
            if ``num_tokens`` is given.
        vllm_config: Global config. Its ``parallel_config`` decides whether
            DP metadata is built.
        num_tokens: Number of tokens in this pass. Used for DP coordination,
            the default BatchDescriptor and the timing statistics.
        num_tokens_across_dp: Precomputed per-DP-rank token counts. If None
            while DP metadata is needed, they are computed here with
            ``coordinate_batch_across_dp`` and microbatching disabled. In that
            case ``ubatch_slices`` must be None and ``num_tokens`` must be set.
        cudagraph_runtime_mode: Cudagraph mode for this pass.
        batch_descriptor: Descriptor for cudagraph dispatch. If omitted while
            a cudagraph mode other than NONE is used and ``num_tokens`` is
            known, ``BatchDescriptor(num_tokens=num_tokens)`` is used.
        ubatch_slices: Microbatch slices, or None when not microbatching.
        slot_mapping: Passed through to ``create_forward_context``.
        skip_compiled: Passed through to ``create_forward_context``.
    """
    global forward_start_time
    # Timing is only collected when the module-level `track_batchsize` flag is
    # set and attention metadata is present.
    need_to_track_batchsize = track_batchsize and attn_metadata is not None
    if need_to_track_batchsize:
        # NOTE(review): forward_start_time is a module global, so nested
        # set_forward_context calls with tracking enabled overwrite each
        # other's start time.
        forward_start_time = time.perf_counter()

    dp_metadata: DPMetadata | None = None
    # DP metadata is only built when DP > 1, the model is not explicitly
    # marked non-MoE (`is not False`, so None still qualifies), and there is
    # either attention metadata or a token count for this pass.
    if (
        vllm_config.parallel_config.data_parallel_size > 1
        and vllm_config.parallel_config.is_moe_model is not False
        and (attn_metadata is not None or num_tokens is not None)
    ):
        # If num_tokens_across_dp hasn't already been initialized, then
        # initialize it here. Both DP padding and Microbatching will be
        # disabled.
        if num_tokens_across_dp is None:
            assert ubatch_slices is None
            assert num_tokens is not None
            _, num_tokens_across_dp, _ = coordinate_batch_across_dp(
                num_tokens_unpadded=num_tokens,
                parallel_config=vllm_config.parallel_config,
                allow_microbatching=False,
            )
            assert num_tokens_across_dp is not None
        # NOTE(review): when num_tokens is None this passes 0, so
        # DPMetadata.make's per-rank check only succeeds if this rank's entry
        # in num_tokens_across_dp is 0.
        dp_metadata = DPMetadata.make(
            vllm_config.parallel_config, num_tokens or 0, num_tokens_across_dp
        )

    # Convenience: if cudagraph is used and num_tokens is given, we can just
    # create a batch descriptor here if not given (there's no harm since if it
    # doesn't match in the wrapper it'll fall through).
    if cudagraph_runtime_mode != CUDAGraphMode.NONE and num_tokens is not None:
        batch_descriptor = batch_descriptor or BatchDescriptor(num_tokens=num_tokens)

    # Platform hook returning extra state stored as ForwardContext's
    # additional_kwargs.
    additional_kwargs = current_platform.set_additional_forward_context(
        attn_metadata=attn_metadata,
        vllm_config=vllm_config,
        dp_metadata=dp_metadata,
        num_tokens=num_tokens,
        num_tokens_across_dp=num_tokens_across_dp,
        cudagraph_runtime_mode=cudagraph_runtime_mode,
        batch_descriptor=batch_descriptor,
        ubatch_slices=ubatch_slices,
    )

    forward_context = create_forward_context(
        attn_metadata,
        vllm_config,
        dp_metadata,
        cudagraph_runtime_mode,
        batch_descriptor,
        ubatch_slices,
        slot_mapping,
        additional_kwargs,
        skip_compiled,
    )

    try:
        with override_forward_context(forward_context):
            yield
    finally:
        # Runs even if the forward pass raised, so timing is recorded either
        # way when tracking is enabled.
        global last_logging_time, batchsize_logging_interval
        if need_to_track_batchsize:
            # NOTE(review): num_tokens is optional, so stats may be keyed
            # under None when it was not passed.
            batchsize = num_tokens
            # we use synchronous scheduling right now,
            # adding a sync point here should not affect
            # scheduling of the next batch
            synchronize = current_platform.synchronize
            if synchronize is not None:
                synchronize()
            now = time.perf_counter()
            # time measurement is in milliseconds
            # (batchsize_forward_time is presumably a defaultdict(list)).
            batchsize_forward_time[batchsize].append((now - forward_start_time) * 1000)
            # Log aggregated stats at most once per batchsize_logging_interval,
            # compared against perf_counter() deltas (seconds).
            if now - last_logging_time > batchsize_logging_interval:
                last_logging_time = now
                forward_stats = []
                for bs, times in batchsize_forward_time.items():
                    if len(times) <= 1:
                        # can be cudagraph / profiling run
                        continue
                    # `medium` holds the median (q=0.5) time in ms, rounded
                    # to 2 decimals.
                    medium = torch.quantile(torch.tensor(times), q=0.5).item()
                    medium = round(medium, 2)
                    forward_stats.append((bs, len(times), medium))
                # Most frequently seen batch sizes first.
                forward_stats.sort(key=lambda x: x[1], reverse=True)
                if forward_stats:
                    logger.info(
                        (
                            "Batchsize forward time stats "
                            "(batchsize, count, median_time(ms)): %s"
                        ),
                        forward_stats,
                    )