vllm.model_executor.offloader.prefetch_ops

Custom ops that make the prefetch offloader compatible with torch.compile and CUDA graphs.

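Each op lists one of its tensor arguments in mutates_args. The implementations shown on this page never write to that tensor; the declaration exists only to create a data dependency that prevents the compiler from reordering the prefetch and sync operations relative to the surrounding computation.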

Functions:

_start_prefetch_fake(output_tensor, layer_idx)

Fake implementation for torch.compile tracing.

Source code in vllm/model_executor/offloader/prefetch_ops.py
def _start_prefetch_fake(
    output_tensor: torch.Tensor,
    layer_idx: int,
) -> None:
    """Fake implementation for torch.compile tracing."""
    return

_start_prefetch_impl(output_tensor, layer_idx)

Start an asynchronous prefetch of the weights for layer layer_idx.

Initiates a host-to-device (H2D) copy on the copy stream for the specified layer.

Parameters:

  • output_tensor (Tensor) – Output from forward; declared as mutated to prevent torch.compile from reordering this op before the computation that produces output_tensor.

  • layer_idx (int) – Index of the layer to prefetch.

Source code in vllm/model_executor/offloader/prefetch_ops.py
def _start_prefetch_impl(
    output_tensor: torch.Tensor,
    layer_idx: int,
) -> None:
    """Start async prefetch of layer_idx weights.

    Initiates H2D copy on the copy stream for the specified layer.

    Args:
        output_tensor: Output from forward - declared as mutated to
            prevent torch.compile from reordering this op before the
            computation that produces output_tensor.
        layer_idx: Index of the layer to prefetch.
    """
    get_offloader()._start_prefetch(layer_idx)

_wait_prefetch_fake(input_tensor, layer_idx)

Fake implementation for torch.compile tracing.

Source code in vllm/model_executor/offloader/prefetch_ops.py
def _wait_prefetch_fake(
    input_tensor: torch.Tensor,
    layer_idx: int,
) -> None:
    """Fake implementation for torch.compile tracing."""
    return

_wait_prefetch_impl(input_tensor, layer_idx)

Wait for the prefetch of layer layer_idx to complete.

Synchronizes the compute stream with the copy stream to ensure the prefetched weights are ready for use.

Parameters:

  • input_tensor (Tensor) – Input to the layer (e.g., hidden_states); declared as mutated to create a data dependency for torch.compile.

  • layer_idx (int) – Index of the layer to wait for.

Source code in vllm/model_executor/offloader/prefetch_ops.py
def _wait_prefetch_impl(
    input_tensor: torch.Tensor,
    layer_idx: int,
) -> None:
    """Wait for prefetch of layer_idx to complete.

    Synchronizes the compute stream with the copy stream to ensure
    the prefetched weights are ready for use.

    Args:
        input_tensor: Input to the layer (e.g., hidden_states) - declared
            as mutated to create data dependency for torch.compile.
        layer_idx: Index of the layer to wait for.
    """
    get_offloader()._wait_for_layer(layer_idx)
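
The call order implied by the two ops (wait on a layer's weights before running it, start a prefetch once an output exists) can be sketched in plain Python. This is only an illustration, not the offloader's actual scheduling: `RecordingOffloader`, `run_layers`, and the `lookahead` parameter are hypothetical names, and the choice of which layer to prefetch next is an assumption. Only the method names `_start_prefetch` and `_wait_for_layer` come from the source above.

```python
class RecordingOffloader:
    """Hypothetical stand-in that logs calls instead of copying weights."""

    def __init__(self):
        self.events = []

    def _start_prefetch(self, layer_idx):
        # The real method would enqueue an H2D copy on the copy stream.
        self.events.append(("start", layer_idx))

    def _wait_for_layer(self, layer_idx):
        # The real method would make the compute stream wait for the copy.
        self.events.append(("wait", layer_idx))


def run_layers(offloader, num_layers, lookahead=1):
    """Drive the offloader in the order the ops are meant to enforce."""
    # Assumption: the first `lookahead` layers are requested up front.
    for i in range(min(lookahead, num_layers)):
        offloader._start_prefetch(i)

    for i in range(num_layers):
        offloader._wait_for_layer(i)             # weights must be ready first
        offloader.events.append(("forward", i))  # placeholder for the layer
        nxt = i + lookahead
        if nxt < num_layers:
            offloader._start_prefetch(nxt)       # issued after the output exists
    return offloader.events
```

Calling `run_layers(RecordingOffloader(), num_layers=3)` yields a log in which every `("wait", i)` entry is preceded by the matching `("start", i)`. That is the ordering the mutates_args declarations are meant to protect once the compiler is free to reorder ops.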

register_prefetch_offloader_ops()

Register custom ops for prefetch offloader.

Must be called before the ops are used. This is typically done at module import time.

Source code in vllm/model_executor/offloader/prefetch_ops.py
def register_prefetch_offloader_ops() -> None:
    """Register custom ops for prefetch offloader.

    Must be called before the ops are used. This is typically done
    at module import time.
    """
    direct_register_custom_op(
        op_name="wait_prefetch",
        op_func=_wait_prefetch_impl,
        mutates_args=["input_tensor"],
        fake_impl=_wait_prefetch_fake,
    )

    direct_register_custom_op(
        op_name="start_prefetch",
        op_func=_start_prefetch_impl,
        mutates_args=["output_tensor"],
        fake_impl=_start_prefetch_fake,
    )