vllm.distributed.weight_transfer

Weight transfer engines for syncing model weights from trainers to inference workers.

Modules:

  • base

    Base class for weight transfer engines.

  • factory

    Factory for weight transfer engines with lazy loading.

  • ipc_engine

    IPC-based weight transfer engine using CUDA IPC for communication.

  • nccl_engine

    NCCL-based weight transfer engine.

  • packed_tensor

    Packed tensor utilities for efficient weight transfer.

Classes:

WeightTransferEngine

Bases: ABC, Generic[TInitInfo, TUpdateInfo]

Base class for weight transfer engines that handle transport of model weights from a trainer to inference workers.

This abstraction separates weight transfer transport logic from the worker implementation, allowing different backends to be plugged in. NCCL and CUDA IPC engines are provided in this package; RDMA is marked as a TODO.

Subclasses should define:

  • init_info_cls

    Type of backend-specific initialization info

  • update_info_cls

    Type of backend-specific update info
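As a minimal sketch of that contract, assuming the info types are plain dataclasses: MiniEngine below is a hypothetical stand-in that copies only the two class attributes and the parse helpers from WeightTransferEngine, so the example runs without vLLM installed. LoopbackInitInfo, LoopbackUpdateInfo and LoopbackEngine are illustrative names, not real backends.

```python
from abc import ABC
from dataclasses import dataclass
from typing import Any, Generic, TypeVar

TInitInfo = TypeVar("TInitInfo")
TUpdateInfo = TypeVar("TUpdateInfo")


class MiniEngine(ABC, Generic[TInitInfo, TUpdateInfo]):
    """Hypothetical stand-in for WeightTransferEngine (class attributes only)."""

    # Subclasses bind these to their concrete info types.
    init_info_cls: type[TInitInfo]
    update_info_cls: type[TUpdateInfo]

    def parse_init_info(self, init_dict: dict[str, Any]) -> TInitInfo:
        return self.init_info_cls(**init_dict)

    def parse_update_info(self, update_dict: dict[str, Any]) -> TUpdateInfo:
        return self.update_info_cls(**update_dict)


# Hypothetical backend-specific info types.
@dataclass
class LoopbackInitInfo:
    world_size: int


@dataclass
class LoopbackUpdateInfo:
    names: list[str]


class LoopbackEngine(MiniEngine[LoopbackInitInfo, LoopbackUpdateInfo]):
    # Binding the class attributes is all a backend needs for the shared
    # parse helpers to build correctly typed info objects.
    init_info_cls = LoopbackInitInfo
    update_info_cls = LoopbackUpdateInfo


engine = LoopbackEngine()
init_info = engine.parse_init_info({"world_size": 2})
update_info = engine.parse_update_info({"names": ["lm_head.weight"]})
```

Because the parse helpers only read self.init_info_cls and self.update_info_cls, each backend gets typed parsing without writing its own parsing code.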

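Methods:

  • __init__

    Initialize the weight transfer engine.

  • init_transfer_engine

    Initialize the weight transfer mechanism.

  • parse_init_info

    Construct typed init info from dict with validation.

  • parse_update_info

    Construct typed update info from dict with validation.

  • receive_sparse_weights

    Receive sparse weight patches from the trainer.

  • receive_weights

    Receive weights from the trainer and load them incrementally.

  • shutdown

    Shut down the weight transfer engine.

  • trainer_send_sparse_weights

    Send sparse weight patches from trainer to inference workers.

  • trainer_send_weights

    Send weights from trainer to inference workers.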

Source code in vllm/distributed/weight_transfer/base.py
class WeightTransferEngine(ABC, Generic[TInitInfo, TUpdateInfo]):
    """
    Base class for weight transfer engines that handle transport of model weights
    from a trainer to inference workers.

    This abstraction separates weight transfer transport logic from the worker
    implementation, allowing different backends (NCCL, CUDA IPC[TODO], RDMA[TODO]) to be
    plugged in.

    Subclasses should define:
        init_info_cls: Type of backend-specific initialization info
        update_info_cls: Type of backend-specific update info
    """

    # Subclasses should override these class attributes
    init_info_cls: type[TInitInfo]
    update_info_cls: type[TUpdateInfo]

    def __init__(
        self,
        config: WeightTransferConfig,
        parallel_config: ParallelConfig,
        model: torch.nn.Module,
    ) -> None:
        """
        Initialize the weight transfer engine.

        Args:
            config: The configuration for the weight transfer engine
            parallel_config: The configuration for the parallel setup
            model: The local model instance which will receive the weights
        """
        self.config = config
        self.parallel_config = parallel_config
        self.model = model

    def parse_init_info(self, init_dict: dict[str, Any]) -> TInitInfo:
        """
        Construct typed init info from dict with validation.

        Args:
            init_dict: Dictionary containing backend-specific initialization parameters

        Returns:
            Typed backend-specific init info dataclass

        Raises:
            ValueError: If init_dict is invalid for this backend
        """
        try:
            return self.init_info_cls(**init_dict)
        except TypeError as e:
            raise ValueError(
                f"Invalid init_info for {self.__class__.__name__}: {e}"
            ) from e

    def parse_update_info(self, update_dict: dict[str, Any]) -> TUpdateInfo:
        """
        Construct typed update info from dict with validation.

        Args:
            update_dict: Dictionary containing backend-specific update parameters

        Returns:
            Typed backend-specific update info dataclass

        Raises:
            ValueError: If update_dict is invalid for this backend
        """
        try:
            return self.update_info_cls(**update_dict)
        except TypeError as e:
            raise ValueError(
                f"Invalid update_info for {self.__class__.__name__}: {e}"
            ) from e

    @abstractmethod
    def init_transfer_engine(self, init_info: TInitInfo) -> None:
        """
        Initialize the weight transfer mechanism.
        This is called once at the beginning of training.

        Args:
            init_info: Backend-specific initialization info
        """
        raise NotImplementedError

    @abstractmethod
    def receive_weights(
        self,
        update_info: TUpdateInfo,
        load_weights: Callable[[list[tuple[str, torch.Tensor]]], None],
    ) -> None:
        """
        Receive weights from the trainer and load them incrementally.

        Args:
            update_info: Backend-specific update info containing parameter metadata
                        and any backend-specific data
            load_weights: Callable that loads weights into the model. Called
                         incrementally for each weight to avoid OOM.
        """
        raise NotImplementedError

    def receive_sparse_weights(
        self,
        update_info: TUpdateInfo,
        apply_patches: Callable[[list[SparseWeightPatch]], None],
    ) -> None:
        """Receive sparse weight patches from the trainer."""
        raise NotImplementedError(
            f"{self.__class__.__name__} does not support sparse weight updates"
        )

    @abstractmethod
    def shutdown(self) -> None:
        """
        Shutdown the weight transfer engine.
        This should be called when the worker is shutting down.
        """
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def trainer_send_weights(
        iterator: Iterator[tuple[str, torch.Tensor]],
        trainer_args: dict[str, Any] | Any,
    ) -> None:
        """
        Send weights from trainer to inference workers.

        This is a static method that can be called from the trainer process
        to send weights to all inference workers.

        Args:
            iterator: Iterator of model parameters. Returns (name, tensor) tuples.
                     The tensors should be on the appropriate device for the backend.
            trainer_args: Dictionary containing backend-specific arguments needed
                         to send weights. The structure depends on the backend:
                         - NCCL: Contains 'group', 'src', 'packed', etc.
                         - IPC: Contains 'mode' ('http' or 'ray'),
                                'llm_handle' (for Ray), 'url' (for HTTP), etc.

        Example:
            >>> param_iter = ((n, p) for n, p in model.named_parameters())
            >>> engine.trainer_send_weights(param_iter, trainer_args)
        """
        raise NotImplementedError

    @staticmethod
    def trainer_send_sparse_weights(
        _iterator: Iterator[SparseWeightPatch],
        _trainer_args: dict[str, Any] | Any,
    ) -> None:
        """Send sparse weight patches from trainer to inference workers."""
        raise NotImplementedError("Sparse weight updates are not supported")

__init__(config, parallel_config, model)

Initialize the weight transfer engine.

Parameters:

  • config

    (WeightTransferConfig) –

    The configuration for the weight transfer engine

  • parallel_config

    (ParallelConfig) –

    The configuration for the parallel setup

  • model

    (Module) –

    The local model instance which will receive the weights

init_transfer_engine(init_info) abstractmethod

Initialize the weight transfer mechanism. This is called once at the beginning of training.

Parameters:

  • init_info

    (TInitInfo) –

    Backend-specific initialization info

parse_init_info(init_dict)

Construct typed init info from dict with validation.

Parameters:

  • init_dict

    (dict[str, Any]) –

    Dictionary containing backend-specific initialization parameters

Returns:

  • TInitInfo

    Typed backend-specific init info dataclass

Raises:

  • ValueError

    If init_dict is invalid for this backend
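The validation in parse_init_info relies on the dataclass constructor rejecting unknown keyword arguments. The sketch below reproduces that logic on a hypothetical ExampleEngine with hypothetical init fields (master_address, master_port); the real backends' init info fields may differ.

```python
from dataclasses import dataclass
from typing import Any


@dataclass
class ExampleInitInfo:
    # Hypothetical fields, for illustration only.
    master_address: str
    master_port: int


class ExampleEngine:
    init_info_cls = ExampleInitInfo

    def parse_init_info(self, init_dict: dict[str, Any]) -> ExampleInitInfo:
        # Same pattern as the base class: let the dataclass constructor
        # reject bad keys, then re-raise the TypeError as a ValueError.
        try:
            return self.init_info_cls(**init_dict)
        except TypeError as e:
            raise ValueError(
                f"Invalid init_info for {self.__class__.__name__}: {e}"
            ) from e


engine = ExampleEngine()
ok = engine.parse_init_info({"master_address": "127.0.0.1", "master_port": 29500})

message = ""
try:
    # "port" is not a field of ExampleInitInfo, so construction fails.
    engine.parse_init_info(
        {"master_address": "127.0.0.1", "master_port": 29500, "port": 1}
    )
except ValueError as err:
    message = str(err)
```

A dataclass constructor checks only which keys are present, not the types of their values, so {"master_port": "29500"} would pass this validation unchanged.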

parse_update_info(update_dict)

Construct typed update info from dict with validation.

Parameters:

  • update_dict

    (dict[str, Any]) –

    Dictionary containing backend-specific update parameters

Returns:

  • TUpdateInfo

    Typed backend-specific update info dataclass

Raises:

  • ValueError

    If update_dict is invalid for this backend
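The same pattern also catches missing required fields. This sketch uses hypothetical ExampleUpdateInfo fields (not the real NCCL or IPC update info) and shows that, because of the raise ... from e, the original TypeError stays reachable as the ValueError's __cause__.

```python
from dataclasses import dataclass, field
from typing import Any, Optional


@dataclass
class ExampleUpdateInfo:
    # Hypothetical parameter metadata; real backends define their own fields.
    names: list[str]
    shapes: list[list[int]]
    dtypes: list[str] = field(default_factory=list)


class ExampleEngine:
    update_info_cls = ExampleUpdateInfo

    def parse_update_info(self, update_dict: dict[str, Any]) -> ExampleUpdateInfo:
        try:
            return self.update_info_cls(**update_dict)
        except TypeError as e:
            raise ValueError(
                f"Invalid update_info for {self.__class__.__name__}: {e}"
            ) from e


engine = ExampleEngine()

# Fields with defaults may be omitted; required ones may not.
info = engine.parse_update_info({"names": ["w"], "shapes": [[2, 3]]})

caught: Optional[ValueError] = None
try:
    engine.parse_update_info({"names": ["w"]})  # "shapes" is missing
except ValueError as err:
    caught = err
```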

receive_sparse_weights(update_info, apply_patches)

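Receive sparse weight patches from the trainer. This method is not abstract: the default implementation raises NotImplementedError, so only backends that support sparse updates need to override it.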
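Since the default implementation simply raises, a backend opts in to sparse updates by overriding the method. In the hypothetical sketch below, BaseSketch copies only the default method from WeightTransferEngine, and plain strings stand in for SparseWeightPatch objects, whose fields are not documented here.

```python
from typing import Any, Callable


class BaseSketch:
    """Hypothetical stand-in reproducing only the default method."""

    def receive_sparse_weights(
        self, update_info: Any, apply_patches: Callable[[list[Any]], None]
    ) -> None:
        raise NotImplementedError(
            f"{self.__class__.__name__} does not support sparse weight updates"
        )


class DenseOnlyEngine(BaseSketch):
    """Hypothetical backend that keeps the default (unsupported) behavior."""


class PatchingEngine(BaseSketch):
    """Hypothetical backend that forwards patches to the callback."""

    def receive_sparse_weights(self, update_info, apply_patches):
        # Assumption: update_info carries the patches directly; a real
        # backend would receive them over its transport instead.
        apply_patches(list(update_info["patches"]))


applied: list[Any] = []
PatchingEngine().receive_sparse_weights({"patches": ["p0", "p1"]}, applied.extend)

error_message = ""
try:
    DenseOnlyEngine().receive_sparse_weights({}, applied.extend)
except NotImplementedError as err:
    error_message = str(err)
```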

receive_weights(update_info, load_weights) abstractmethod

Receive weights from the trainer and load them incrementally.

Parameters:

  • update_info

    (TUpdateInfo) –

    Backend-specific update info containing parameter metadata and any backend-specific data

  • load_weights

    (Callable[[list[tuple[str, Tensor]]], None]) –

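    Callable that loads weights into the model. It is called incrementally, each time with a list of (name, tensor) pairs, rather than once with the full set of weights, to avoid running out of memory (OOM).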
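To illustrate incremental loading, here is a hypothetical receive loop that forwards weights to load_weights in fixed-size batches. The batch size of 2, the iterator standing in for the transport, and plain float lists standing in for torch.Tensor are all assumptions; each real backend decides its own granularity.

```python
from typing import Callable, Iterator

# Plain lists of floats stand in for torch.Tensor so the sketch runs anywhere.
Weight = tuple[str, list[float]]


def receive_in_batches(
    incoming: Iterator[Weight],
    load_weights: Callable[[list[Weight]], None],
    batch_size: int = 2,
) -> None:
    """Hypothetical receive loop: hand weights to load_weights in small
    batches instead of collecting every tensor before loading."""
    batch: list[Weight] = []
    for name, tensor in incoming:
        batch.append((name, tensor))
        if len(batch) == batch_size:
            load_weights(batch)
            batch = []  # this loop no longer holds references to loaded tensors
    if batch:
        load_weights(batch)  # flush the final, possibly smaller, batch


calls: list[list[str]] = []
model_state: dict[str, list[float]] = {}


def load_weights(weights: list[Weight]) -> None:
    # Stand-in for the model's loader: record names and copy values in.
    calls.append([name for name, _ in weights])
    model_state.update(weights)


stream = ((f"layer{i}.weight", [float(i)]) for i in range(5))
receive_in_batches(stream, load_weights)
```

Peak memory on the receiving side is then bounded by one batch plus whatever the model already holds, rather than by a second full copy of all weights.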

shutdown() abstractmethod

Shut down the weight transfer engine. This should be called when the worker is shutting down.
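Putting the documented call order together (init_transfer_engine once at the start of training, receive_weights for each update, shutdown when the worker exits), a toy in-process engine might look like this. ToyEngine, its info classes, and the RuntimeError guard against receiving before initialization are hypothetical additions, not behavior of the base class.

```python
from dataclasses import dataclass
from typing import Any, Callable


@dataclass
class ToyInitInfo:
    rank: int


@dataclass
class ToyUpdateInfo:
    weights: list[tuple[str, float]]


class ToyEngine:
    """Hypothetical in-process engine that records each lifecycle step."""

    def __init__(self) -> None:
        self.events: list[str] = []
        self.ready = False

    def init_transfer_engine(self, init_info: ToyInitInfo) -> None:
        # One-time setup; a real backend would create its communicator here.
        self.ready = True
        self.events.append(f"init(rank={init_info.rank})")

    def receive_weights(
        self,
        update_info: ToyUpdateInfo,
        load_weights: Callable[[list[tuple[str, Any]]], None],
    ) -> None:
        # Hypothetical guard: refuse updates before init or after shutdown.
        if not self.ready:
            raise RuntimeError("init_transfer_engine() must be called first")
        load_weights(update_info.weights)
        self.events.append("receive")

    def shutdown(self) -> None:
        # Release transport resources; no further updates are accepted.
        self.ready = False
        self.events.append("shutdown")


engine = ToyEngine()
engine.init_transfer_engine(ToyInitInfo(rank=0))  # once, before training
for step in range(2):  # once per weight sync
    engine.receive_weights(ToyUpdateInfo([("w", float(step))]), lambda ws: None)
engine.shutdown()  # when the worker exits
```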

trainer_send_sparse_weights(_iterator, _trainer_args) staticmethod

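Send sparse weight patches from trainer to inference workers. This static method is not abstract: the default implementation raises NotImplementedError("Sparse weight updates are not supported").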
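Because the base class provides a default that raises NotImplementedError, trainer code could detect sparse support up front by checking whether a backend overrides the static method. The helper supports_sparse and the classes below are hypothetical; BaseSketch copies only the default method from WeightTransferEngine.

```python
from typing import Any, Iterator


class BaseSketch:
    @staticmethod
    def trainer_send_sparse_weights(
        _iterator: Iterator[Any], _trainer_args: Any
    ) -> None:
        raise NotImplementedError("Sparse weight updates are not supported")


class DenseOnly(BaseSketch):
    """Hypothetical backend that keeps the default."""


class SparseCapable(BaseSketch):
    """Hypothetical backend that overrides the sparse path."""

    @staticmethod
    def trainer_send_sparse_weights(iterator, trainer_args) -> None:
        for _patch in iterator:
            pass  # a real backend would transmit each patch here


def supports_sparse(engine_cls: type) -> bool:
    """Hypothetical helper: True if engine_cls overrides the default.

    Looking up a staticmethod on a class returns the plain function, so a
    subclass without an override yields the very same function object.
    """
    return (
        engine_cls.trainer_send_sparse_weights
        is not BaseSketch.trainer_send_sparse_weights
    )
```

Calling the method and catching NotImplementedError would also work, since the default raises before consuming the iterator; the identity check simply lets the trainer choose a dense or sparse path before building any patches.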

trainer_send_weights(iterator, trainer_args) abstractmethod staticmethod

Send weights from trainer to inference workers.

This is a static method that can be called from the trainer process to send weights to all inference workers.

Parameters:

  • iterator

    (Iterator[tuple[str, Tensor]]) –

    Iterator of model parameters that yields (name, tensor) tuples. The tensors should be on the appropriate device for the backend.

  • trainer_args

    (dict[str, Any] | Any) –

    Dictionary containing backend-specific arguments needed to send weights. The structure depends on the backend:

      - NCCL: contains 'group', 'src', 'packed', etc.
      - IPC: contains 'mode' ('http' or 'ray'), 'llm_handle' (for Ray), 'url' (for HTTP), etc.

Example:

    >>> param_iter = ((n, p) for n, p in model.named_parameters())
    >>> engine.trainer_send_weights(param_iter, trainer_args)
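Because trainer_send_weights is a static method, the trainer can call it on the engine class without constructing an engine. In this minimal sketch, QueueEngine and its 'outbox' key in trainer_args are hypothetical, appending to a list stands in for the actual transport, and a dict of floats stands in for model.named_parameters().

```python
from typing import Any, Iterator


class QueueEngine:
    """Hypothetical backend whose 'transport' is a list in trainer_args."""

    @staticmethod
    def trainer_send_weights(
        iterator: Iterator[tuple[str, Any]],
        trainer_args: dict[str, Any],
    ) -> None:
        outbox: list = trainer_args["outbox"]  # hypothetical key
        for name, tensor in iterator:
            outbox.append((name, tensor))  # stand-in for a send over the wire


# Stand-in for model.named_parameters(); plain floats replace tensors.
named_parameters = {"embed.weight": 0.5, "lm_head.weight": -1.0}

outbox: list = []
# Called on the class from the trainer process; no engine instance needed.
QueueEngine.trainer_send_weights(iter(named_parameters.items()), {"outbox": outbox})
```

Since the first argument is an iterator, it is consumed as the weights are sent, so a fresh one is needed for every sync.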

WeightTransferEngineFactory

Factory for creating weight transfer engines with lazy loading.

This factory implements a registry pattern that supports:

  • Lazy loading: engine modules are only imported when actually needed
  • Extensibility: custom engines can be registered at runtime
  • Centralized registration: all built-in engines are registered in one place
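The lazy-loading part of this pattern amounts to storing a loader closure instead of a class. The sketch below mirrors that idea with module-level helpers (register_lazy and resolve are hypothetical names), uses the standard-library collections.OrderedDict in place of an engine class, and records each import in imports_done to make the deferral visible.

```python
import importlib
from typing import Callable

_registry: dict[str, Callable[[], type]] = {}
imports_done: list[str] = []  # instrumentation only


def register_lazy(name: str, module_path: str, class_name: str) -> None:
    # Store a loader closure; nothing is imported at registration time.
    def loader() -> type:
        imports_done.append(module_path)
        module = importlib.import_module(module_path)
        return getattr(module, class_name)

    _registry[name] = loader


def resolve(name: str) -> type:
    # The import happens here, the first time the class is actually needed.
    return _registry[name]()


register_lazy("ordered", "collections", "OrderedDict")
registered_before_use = list(imports_done)  # still empty
engine_cls = resolve("ordered")
```

After the first import, importlib.import_module returns the cached module from sys.modules, so resolving the same name again is cheap.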

Methods:

  • create_engine

    Create a weight transfer engine instance.

  • register_engine

    Register an engine with lazy-loading or direct class reference.

Source code in vllm/distributed/weight_transfer/factory.py
class WeightTransferEngineFactory:
    """Factory for creating weight transfer engines with lazy loading.

    This factory implements a registry pattern that supports:
    - Lazy loading: Engine modules are only imported when actually needed
    - Extensibility: Custom engines can be registered at runtime
    - Centralized registration: All built-in engines registered in one place
    """

    _registry: dict[str, Callable[[], type[WeightTransferEngine]]] = {}

    @classmethod
    def register_engine(
        cls,
        name: str,
        module_path_or_cls: str | type[WeightTransferEngine],
        class_name: str | None = None,
    ) -> None:
        """Register an engine with lazy-loading or direct class reference.

        Supports two calling conventions:
        1. Lazy loading: register_engine(name, module_path, class_name)
        2. Direct class: register_engine(name, engine_cls)

        Args:
            name: The name to register the engine under (e.g., "nccl")
            module_path_or_cls: Either a module path string for lazy loading,
                or the engine class directly
            class_name: Name of the engine class (required if module_path is string)

        Raises:
            ValueError: If an engine with the same name is already registered
        """
        if name in cls._registry:
            raise ValueError(f"Weight transfer engine '{name}' is already registered.")

        if isinstance(module_path_or_cls, str):
            # Lazy loading path
            module_path = module_path_or_cls
            if class_name is None:
                raise ValueError(
                    "class_name is required when registering with module path"
                )

            def loader() -> type[WeightTransferEngine]:
                module = importlib.import_module(module_path)
                return getattr(module, class_name)

            cls._registry[name] = loader
        else:
            # Direct class registration
            engine_cls = module_path_or_cls
            cls._registry[name] = lambda: engine_cls

    @classmethod
    def create_engine(
        cls,
        config: "WeightTransferConfig",
        parallel_config: "ParallelConfig",
        model: "torch.nn.Module",
    ) -> WeightTransferEngine:
        """Create a weight transfer engine instance.

        Args:
            config: Weight transfer configuration containing the backend name
            parallel_config: Parallel configuration for the engine
            model: The local model instance which will receive the weights

        Returns:
            An initialized weight transfer engine instance

        Raises:
            ValueError: If the backend is not registered
        """
        backend = config.backend
        if backend not in cls._registry:
            available = list(cls._registry.keys())
            raise ValueError(
                f"Invalid weight transfer backend: {backend}. "
                f"Available engines: {available}"
            )
        engine_cls = cls._registry[backend]()

        logger.info(
            "Creating weight transfer engine: %s",
            engine_cls.__name__,
        )

        return engine_cls(config, parallel_config, model)

create_engine(config, parallel_config, model) classmethod

Create a weight transfer engine instance.

Parameters:

  • config

    (WeightTransferConfig) –

    Weight transfer configuration containing the backend name

  • parallel_config

    (ParallelConfig) –

    Parallel configuration for the engine

  • model

    (Module) –

    The local model instance which will receive the weights

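Returns:

  • WeightTransferEngine

    An initialized weight transfer engine instance

Raises:

  • ValueError

    If the backend is not registered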
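create_engine resolves the class through the registry and then calls it with the same three arguments that the engine's __init__ takes. The following sketch reproduces that flow with hypothetical stand-ins (FakeConfig with only a backend field, EchoEngine, and a one-entry registry) so it runs without vLLM.

```python
from dataclasses import dataclass
from typing import Any, Callable


@dataclass
class FakeConfig:
    backend: str  # mirrors the backend field read by create_engine


class EchoEngine:
    """Hypothetical engine that just stores its constructor arguments."""

    def __init__(self, config: Any, parallel_config: Any, model: Any) -> None:
        self.config, self.parallel_config, self.model = config, parallel_config, model


_registry: dict[str, Callable[[], type]] = {"echo": lambda: EchoEngine}


def create_engine(config: FakeConfig, parallel_config: Any, model: Any) -> Any:
    backend = config.backend
    if backend not in _registry:
        available = list(_registry.keys())
        raise ValueError(
            f"Invalid weight transfer backend: {backend}. "
            f"Available engines: {available}"
        )
    engine_cls = _registry[backend]()  # resolve (and possibly import) the class
    return engine_cls(config, parallel_config, model)


engine = create_engine(FakeConfig("echo"), parallel_config=None, model="model")

error_message = ""
try:
    create_engine(FakeConfig("rdma"), None, None)
except ValueError as err:
    error_message = str(err)
```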

register_engine(name, module_path_or_cls, class_name=None) classmethod

Register an engine with lazy-loading or direct class reference.

Supports two calling conventions:

  1. Lazy loading: register_engine(name, module_path, class_name)
  2. Direct class: register_engine(name, engine_cls)

Parameters:

  • name

    (str) –

    The name to register the engine under (e.g., "nccl")

  • module_path_or_cls

    (str | type[WeightTransferEngine]) –

    Either a module path string for lazy loading, or the engine class directly

  • class_name

    (str | None, default: None) –

    Name of the engine class (required if module_path_or_cls is a string)

Raises:

  • ValueError

    If an engine with the same name is already registered
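Both calling conventions, and both ValueError cases, can be exercised with a copy of the method body on a hypothetical MiniFactory class. MyEngine is a placeholder class, and the standard-library fractions.Fraction stands in for a lazily imported engine.

```python
import importlib
from typing import Callable, Optional, Union


class MiniFactory:
    _registry: dict[str, Callable[[], type]] = {}

    @classmethod
    def register_engine(
        cls,
        name: str,
        module_path_or_cls: Union[str, type],
        class_name: Optional[str] = None,
    ) -> None:
        if name in cls._registry:
            raise ValueError(f"Weight transfer engine '{name}' is already registered.")
        if isinstance(module_path_or_cls, str):
            # Convention 1: lazy loading from a module path and class name.
            module_path = module_path_or_cls
            if class_name is None:
                raise ValueError(
                    "class_name is required when registering with module path"
                )

            def loader() -> type:
                return getattr(importlib.import_module(module_path), class_name)

            cls._registry[name] = loader
        else:
            # Convention 2: direct class reference.
            engine_cls = module_path_or_cls
            cls._registry[name] = lambda: engine_cls


class MyEngine:  # hypothetical custom engine class
    pass


MiniFactory.register_engine("mine", MyEngine)  # direct class
MiniFactory.register_engine("frac", "fractions", "Fraction")  # lazy

errors: list[str] = []
for args in [("mine", MyEngine), ("broken", "fractions")]:
    try:
        MiniFactory.register_engine(*args)
    except ValueError as err:
        errors.append(str(err))
```

Because _registry is a class attribute, registrations are shared across the whole process, and a name can be registered only once; there is no overwrite path.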
