
vllm.model_executor.kernels.mhc.aiter

Functions:

mhc_pre_aiter(residual, fn, hc_scale, hc_base, rms_eps, hc_pre_eps, hc_sinkhorn_eps, hc_post_mult_value, sinkhorn_repeat, n_splits=1)

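Forward pass for the mHC (manifold-constrained hyper-connections) pre block, dispatched to the ROCm AITER kernel rocm_aiter_ops.mhc_pre.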

Parameters:

  • residual

    (Tensor) –

    shape (..., hc_mult, hidden_size), dtype torch.bfloat16; hidden_size must be a multiple of 256 (enforced by an assertion)

  • fn

    (Tensor) –

    shape (hc_mult3, hc_mult * hidden_size), dtype torch.float32

  • hc_scale

    (Tensor) –

    shape (3,), dtype torch.float32

  • hc_base

    (Tensor) –

    shape (hc_mult3,), dtype torch.float32

  • rms_eps

    (float) –

    RMS normalization epsilon

  • hc_pre_eps

    (float) –

    pre-mix epsilon

  • hc_sinkhorn_eps

    (float) –

    Sinkhorn epsilon

  • hc_post_mult_value

    (float) –

    post-mix multiplier value

  • sinkhorn_repeat

    (int) –

    number of Sinkhorn iterations

  • n_splits

    (int, default: 1 ) –

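    split-k factor. Note: this wrapper does not forward n_splits to rocm_aiter_ops.mhc_pre, so the value currently has no effect.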
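The parameter list above fixes a shape contract between the inputs and outputs. The following minimal sketch checks that contract in pure Python using only shape tuples. check_mhc_pre_shapes is a hypothetical helper and is not part of vLLM. It reads hc_mult3 from fn.shape[0] and does not assume any particular relation between hc_mult3 and hc_mult. The value 24 in the usage line is only an example.

```python
def check_mhc_pre_shapes(residual_shape, fn_shape, hc_scale_shape, hc_base_shape):
    """Hypothetical helper: validate mhc_pre_aiter input shapes and return the
    documented output shapes. Works on plain shape tuples, not tensors."""
    # residual: (..., hc_mult, hidden_size); leading dims are arbitrary batch dims.
    *batch, hc_mult, hidden_size = residual_shape
    # Mirrors the `assert hidden_size % 256 == 0` in the wrapper.
    if hidden_size % 256 != 0:
        raise ValueError(f"hidden_size={hidden_size} must be a multiple of 256")
    # fn: (hc_mult3, hc_mult * hidden_size); hc_mult3 is taken as given.
    hc_mult3, fn_cols = fn_shape
    if fn_cols != hc_mult * hidden_size:
        raise ValueError(
            f"fn needs hc_mult * hidden_size = {hc_mult * hidden_size} columns, "
            f"got {fn_cols}"
        )
    if tuple(hc_scale_shape) != (3,):
        raise ValueError(f"hc_scale must have shape (3,), got {tuple(hc_scale_shape)}")
    if tuple(hc_base_shape) != (hc_mult3,):
        raise ValueError(
            f"hc_base must have shape ({hc_mult3},), got {tuple(hc_base_shape)}"
        )
    batch = tuple(batch)
    return {
        "post_mix": batch + (hc_mult,),              # float32
        "comb_mix": batch + (hc_mult, hc_mult),      # float32
        "layer_input": batch + (hidden_size,),       # bfloat16
    }


shapes = check_mhc_pre_shapes((8, 4, 4096), (24, 4 * 4096), (3,), (24,))
print(shapes["comb_mix"])  # (8, 4, 4)
```

Working on shape tuples keeps the check independent of a ROCm device. A real caller could pass tensor.shape values, since torch.Size is a tuple subclass.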

Returns:

  • post_mix (Tensor) –

    shape (..., hc_mult), dtype torch.float32

  • comb_mix (Tensor) –

    shape (..., hc_mult, hc_mult), dtype torch.float32

  • layer_input (Tensor) –

    shape (..., hidden_size), dtype torch.bfloat16
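comb_mix is a per-token (hc_mult, hc_mult) matrix, and the kernel takes a Sinkhorn epsilon (hc_sinkhorn_eps) and an iteration count (sinkhorn_repeat). The following minimal pure-Python sketch shows generic Sinkhorn-Knopp normalization, which alternately rescales the rows and columns of a positive square matrix toward a doubly stochastic one. The sketch assumes that eps is added to each row and column denominator. Where and how the AITER kernel actually applies hc_sinkhorn_eps is not stated here, so treat this as illustrative rather than as the kernel's algorithm.

```python
def sinkhorn(matrix, repeat, eps):
    """Generic Sinkhorn-Knopp normalization (illustrative, not the AITER kernel).

    Alternately normalizes rows and columns of a positive square matrix so each
    sums to approximately 1. Assumption: `eps` is added to every denominator to
    guard against division by zero.
    """
    m = [list(row) for row in matrix]  # copy so the input is not mutated
    n = len(m)
    for _ in range(repeat):
        # Row normalization: divide each row by its sum.
        for i in range(n):
            s = sum(m[i]) + eps
            m[i] = [x / s for x in m[i]]
        # Column normalization: divide each column by its sum.
        for j in range(n):
            s = sum(m[r][j] for r in range(n)) + eps
            for r in range(n):
                m[r][j] /= s
    return m


comb = sinkhorn([[1.0, 2.0], [3.0, 4.0]], repeat=20, eps=1e-6)
```

Because the column step runs last, column sums are within about eps of 1 after any iteration. Row sums approach 1 as the iteration count grows.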

Source code in vllm/model_executor/kernels/mhc/aiter.py
def mhc_pre_aiter(
    residual: torch.Tensor,
    fn: torch.Tensor,
    hc_scale: torch.Tensor,
    hc_base: torch.Tensor,
    rms_eps: float,
    hc_pre_eps: float,
    hc_sinkhorn_eps: float,
    hc_post_mult_value: float,
    sinkhorn_repeat: int,
    n_splits: int = 1,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Forward pass for the mHC pre block, dispatched to the ROCm AITER kernel.

    Args:
        residual: shape (..., hc_mult, hidden_size), dtype torch.bfloat16;
            hidden_size must be a multiple of 256
        fn: shape (hc_mult3, hc_mult * hidden_size), dtype torch.float32
        hc_scale: shape (3,), dtype torch.float32
        hc_base: shape (hc_mult3,), dtype torch.float32
        rms_eps: RMS normalization epsilon
        hc_pre_eps: pre-mix epsilon
        hc_sinkhorn_eps: Sinkhorn epsilon
        hc_post_mult_value: post-mix multiplier value
        sinkhorn_repeat: number of Sinkhorn iterations
        n_splits: split-k factor; currently not forwarded to
            rocm_aiter_ops.mhc_pre, so it has no effect

    Returns:
        post_mix: shape (..., hc_mult), dtype torch.float32
        comb_mix: shape (..., hc_mult, hc_mult), dtype torch.float32
        layer_input: shape (..., hidden_size), dtype torch.bfloat16
    """

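    hidden_size = residual.shape[-1]
    # This wrapper only accepts hidden sizes that are multiples of 256.
    assert hidden_size % 256 == 0
    # Imported lazily so the module can be imported without ROCm/AITER present.
    from vllm._aiter_ops import rocm_aiter_ops

    # NOTE: n_splits is not passed through to the AITER op.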

    return rocm_aiter_ops.mhc_pre(
        residual,
        fn,
        hc_scale,
        hc_base,
        rms_eps,
        hc_pre_eps,
        hc_sinkhorn_eps,
        hc_post_mult_value,
        sinkhorn_repeat,
    )