vllm.utils.deep_gemm ¶
Compatibility wrapper for DeepGEMM API changes.
Users of vLLM should always import only these wrappers.
Classes:
Functions:
-
calc_diff–Return a global difference metric for unit tests.
-
fp8_fp4_mqa_logits–Compute MQA logits for a single sequence without KV paging.
-
fp8_fp4_paged_mqa_logits–Compute MQA logits using a paged KV-cache.
-
get_col_major_tma_aligned_tensor–Wrapper for DeepGEMM's get_mn_major_tma_aligned_tensor
-
get_paged_mqa_logits_metadata–Build scheduling metadata for paged MQA logits.
-
is_deep_gemm_e8m0_used–Return True if vLLM is configured to use DeepGEMM E8M0 scale on a Hopper or Blackwell-class GPU.
-
is_deep_gemm_supported–Return True if DeepGEMM is supported on the current platform.
DeepGemmQuantScaleFMT ¶
Bases: Enum
Methods:
-
from_oracle–Return the pre-initialized oracle decision
-
init_oracle_cache–Initialize the oracle decision and store it in the class cache
Source code in vllm/utils/deep_gemm.py
from_oracle() classmethod ¶
Return the pre-initialized oracle decision
init_oracle_cache() classmethod ¶
Initialize the oracle decision and store it in the class cache
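The two class methods describe a compute-once, read-many pattern. The following is a minimal sketch of that pattern, assuming a class-level cache attribute. The enum members, the `decide` callback, and the error raised on an uninitialized read are all hypothetical and are not taken from vLLM.

```python
from enum import Enum


class ScaleFormatSketch(Enum):
    # Hypothetical members; the real enum's values are not listed on this page.
    FLOAT32 = 0
    E8M0 = 1

    @classmethod
    def init_oracle_cache(cls, decide):
        """Run the (hypothetical) decision callback once and cache the result."""
        if getattr(cls, "_oracle_cache", None) is None:
            cls._oracle_cache = decide()

    @classmethod
    def from_oracle(cls):
        """Return the cached decision; assumes init_oracle_cache() ran first."""
        cached = getattr(cls, "_oracle_cache", None)
        if cached is None:
            raise RuntimeError("oracle cache not initialized")
        return cached


calls = []


def decide():
    # Stand-in for whatever platform/config inspection the real oracle performs.
    calls.append(1)
    return ScaleFormatSketch.E8M0


ScaleFormatSketch.init_oracle_cache(decide)
ScaleFormatSketch.init_oracle_cache(decide)  # second call is a no-op
chosen = ScaleFormatSketch.from_oracle()
```

Caching on the class keeps later reads cheap and guarantees every caller sees the same decision.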
_import_deep_gemm() cached ¶
Import the deep_gemm module.
Prefers an externally installed deep_gemm package (so users can pin a specific version), then falls back to the vendored copy bundled in the vLLM wheel.
Returns None when neither source is usable.
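The "external package first, vendored copy second, otherwise None" order can be illustrated with a small generic helper. This is a sketch only: `import_first_available` and the module names are hypothetical, and standard-library modules stand in for the real packages.

```python
import functools
import importlib


@functools.cache
def import_first_available(candidates):
    """Return the first importable module named in `candidates`, or None."""
    for name in candidates:
        try:
            return importlib.import_module(name)
        except ImportError:
            continue  # this source is unusable, try the next one
    return None


# Stand-in names: the first plays the role of a missing external package,
# the second the role of a bundled fallback copy.
backend = import_first_available(("no_such_external_pkg", "json"))
nothing = import_first_available(("no_such_external_pkg",))
```

Because the helper is cached, `candidates` must be hashable (a tuple), and the import is attempted at most once per tuple.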
_lazy_init() ¶
Import deep_gemm and resolve symbols on first use.
_missing(*_, **__) ¶
Placeholder for unavailable DeepGEMM backend.
calc_diff(x, y) ¶
Return a global difference metric for unit tests.
DeepGEMM kernels on Blackwell/B200 currently exhibit noticeable per-element error, causing torch.testing.assert_close to fail. Instead of checking every element, we compute a cosine-style similarity over the whole tensor and report 1 - sim. Once kernel accuracy improves this helper can be removed.
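To make "cosine-style similarity over the whole tensor" concrete, here is a pure-Python sketch over flat lists. The formula, sim = 2·⟨x, y⟩ / (⟨x, x⟩ + ⟨y, y⟩), is an assumption about what "cosine-style" means here, and `calc_diff_sketch` is a hypothetical name. The real helper operates on tensors.

```python
def calc_diff_sketch(x, y):
    """Return 1 - sim, where sim = 2*<x, y> / (<x, x> + <y, y>).

    A single global number: 0.0 for identical inputs, growing as they diverge.
    Assumes at least one input is non-zero (otherwise the denominator is 0).
    """
    dot = sum(a * b for a, b in zip(x, y))
    denom = sum(a * a for a in x) + sum(b * b for b in y)
    return 1.0 - 2.0 * dot / denom


same = calc_diff_sketch([1.0, 2.0, 3.0], [1.0, 2.0, 3.0])
flipped = calc_diff_sketch([1.0, 2.0, 3.0], [-1.0, -2.0, -3.0])
```

A test would then compare the result against a small threshold instead of checking every element.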
fp8_fp4_mqa_logits(q, kv, weights, cu_seqlen_ks, cu_seqlen_ke, clean_logits) ¶
Compute MQA logits for a single sequence without KV paging.
Unified FP8/FP4 dispatch: the underlying DeepGEMM kernel takes q = (values, scales_or_None), where scales is None for FP8 Q (the per-token scale is folded into weights) and a packed block-scale tensor for MXFP4 Q.
Parameters:
-
q (tuple[Tensor, Tensor | None]) – Tuple (q_values, q_scale). FP8 path: q_values is [M, H, D] float8_e4m3fn and q_scale is None (the per-token scale is folded into weights). FP4 path: q_values is packed uint8 and q_scale is the companion block-scale tensor.
-
kv (tuple[Tensor, Tensor]) – Tuple (k_packed, k_scales). FP8 layout is [N, D] float8_e4m3fn plus fp32 scales [N]; FP4 layout is packed uint8.
-
weights (Tensor) – Weights of shape [M, H], dtype torch.float32.
-
cu_seqlen_ks (Tensor) – Start indices (inclusive) for valid K per query position, shape [M], dtype int32.
-
cu_seqlen_ke (Tensor) – End indices (exclusive) for valid K per query position, shape [M], dtype int32.
-
clean_logits (bool) – Whether to set the unfilled logits to -inf.
Returns:
-
Tensor–Logits tensor of shape [M, N], dtype torch.float32.
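The per-query window described by cu_seqlen_ks and cu_seqlen_ke, together with clean_logits, can be read as a masking rule on the [M, N] output. The sketch below shows that reading on nested lists. It does not reproduce the kernel's arithmetic, and `mask_outside_window` is a hypothetical helper.

```python
NEG_INF = float("-inf")


def mask_outside_window(logits, ks, ke):
    """Keep logits[m][n] only where ks[m] <= n < ke[m]; fill the rest with -inf."""
    return [
        [v if ks[m] <= n < ke[m] else NEG_INF for n, v in enumerate(row)]
        for m, row in enumerate(logits)
    ]


# Two query positions (M = 2) against four keys (N = 4).
logits = [[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8]]
masked = mask_outside_window(logits, ks=[0, 1], ke=[2, 4])
```

Query 0 keeps keys 0 and 1, and query 1 keeps keys 1 through 3. Every other position becomes -inf, so a later top-k or softmax ignores it.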
fp8_fp4_paged_mqa_logits(q, kv_cache, weights, context_lens, block_tables, schedule_metadata, max_model_len, clean_logits) ¶
Compute MQA logits using a paged KV-cache.
Unified FP8/FP4 dispatch: the underlying DeepGEMM kernel takes q = (values, scales_or_None). Pass (q_tensor, None) for the FP8 path and (q_values, q_scale) for MXFP4.
Parameters:
-
q (tuple[Tensor, Tensor | None]) – Tuple (q_values, q_scale). FP8 path: q_values is [B, next_n, H, D] float8_e4m3fn and q_scale is None. FP4 path: q_values is packed uint8 and q_scale is the companion block-scale tensor.
-
kv_cache (Tensor) – Paged KV-cache. FP8 layout is [num_blocks, block_size, 1, D+4], dtype torch.uint8, with the last 4 bytes per (block, pos) storing the float dequant scale.
-
weights (Tensor) – Tensor of shape [B * next_n, H], dtype torch.float32.
-
context_lens (Tensor) – Tensor of shape [B], dtype int32; effective context length for each batch element.
-
block_tables (Tensor) – Tensor of shape [B, max_blocks], dtype int32; maps logical block indices to physical blocks in the paged cache.
-
schedule_metadata (Tensor) – Returned by get_paged_mqa_logits_metadata; used to distribute work across SMs.
-
max_model_len (int) – Maximum sequence length used to size the logits output.
-
clean_logits (bool) – Whether to set the unfilled logits to -inf.
Returns:
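Two details of the paged layout are easier to follow with a worked example: how block_tables maps a token position to a physical block, and how one [D+4]-byte cache entry splits into values and scale. The sketch below uses plain Python bytes. The helper names are hypothetical, and little-endian float32 storage for the scale is an assumption.

```python
import struct


def locate_token(block_table_row, pos, block_size):
    """Map a logical token position to (physical_block, offset_in_block)."""
    return block_table_row[pos // block_size], pos % block_size


def split_kv_entry(entry, head_dim):
    """Split one (block, pos) entry into D value bytes and its trailing scale."""
    values = bytes(entry[:head_dim])
    # Assumption: the 4 trailing bytes are a little-endian float32.
    (scale,) = struct.unpack("<f", bytes(entry[head_dim:head_dim + 4]))
    return values, scale


# One sequence whose logical blocks 0, 1, 2 live in physical blocks 7, 2, 9.
physical_block, offset = locate_token([7, 2, 9], pos=130, block_size=64)

# A toy entry with D = 8 value bytes followed by a scale of 0.5.
entry = bytes(range(8)) + struct.pack("<f", 0.5)
values, scale = split_kv_entry(entry, head_dim=8)
```

Position 130 falls in logical block 2 (130 // 64) at offset 2, which the table maps to physical block 9.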
get_col_major_tma_aligned_tensor(x) ¶
Wrapper for DeepGEMM's get_mn_major_tma_aligned_tensor
get_paged_mqa_logits_metadata(context_lens, block_size, num_sms) ¶
Build scheduling metadata for paged MQA logits.
Parameters:
-
context_lens (Tensor) – Tensor of shape [B], dtype int32; effective context length per batch element.
-
block_size (int) – KV-cache block size in tokens (e.g., 64).
-
num_sms (int) – Number of SMs available (132 for Hopper).
Returns:
-
Tensor–Backend-specific tensor consumed by fp8_fp4_paged_mqa_logits to schedule work across SMs.
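The returned tensor's format is backend-specific, so the following sketch does not reproduce it. It only illustrates the quantities the inputs imply: how many KV blocks each sequence spans, and what an even split of that work across SMs could look like. Both helpers are hypothetical, and the even split is an illustrative assumption, not DeepGEMM's scheduling policy.

```python
def kv_blocks_per_sequence(context_lens, block_size):
    """Number of KV-cache blocks each sequence touches (ceiling division)."""
    return [-(-n // block_size) for n in context_lens]


def naive_even_split(total_blocks, num_sms):
    """Spread `total_blocks` work items over `num_sms` as evenly as possible."""
    base, extra = divmod(total_blocks, num_sms)
    return [base + (1 if i < extra else 0) for i in range(num_sms)]


blocks = kv_blocks_per_sequence([1, 64, 65, 0], block_size=64)
per_sm = naive_even_split(sum(blocks) + 6, num_sms=4)  # 10 work items on 4 SMs
```

Context lengths usually differ widely within a batch, so some precomputed assignment of this kind keeps all SMs busy.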
is_deep_gemm_e8m0_used() cached ¶
Return True if vLLM is configured to use DeepGEMM E8M0 scale on a Hopper or Blackwell-class GPU.
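For readers unfamiliar with the term, E8M0 is an exponent-only 8-bit format: each code represents a power of two with a bias of 127. The sketch below shows how a floating-point scale could be snapped to such a value. Rounding up rather than to nearest is an assumption, and neither helper is part of vLLM or DeepGEMM.

```python
import math

E8M0_BIAS = 127


def encode_e8m0(scale):
    """Round a positive scale up to a power of two; return its biased exponent."""
    exponent = math.ceil(math.log2(scale))
    return min(max(exponent + E8M0_BIAS, 0), 254)  # code 255 is reserved for NaN


def decode_e8m0(code):
    """Return the power of two represented by an E8M0 code."""
    return 2.0 ** (code - E8M0_BIAS)


code = encode_e8m0(0.3)      # 0.3 rounds up to 0.5 == 2**-1
restored = decode_e8m0(code)
```

Because the scale has no mantissa bits, applying it changes only the exponent of the value being scaled.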
is_deep_gemm_supported() cached ¶
Return True if DeepGEMM is supported on the current platform. Currently, only Hopper and Blackwell GPUs are supported.
should_auto_disable_deep_gemm(model_type) ¶
Check if DeepGemm should be auto-disabled for this model on Blackwell.
Returns True if the model is known to have accuracy degradation with DeepGemm's E8M0 scale format on Blackwell GPUs (SM100+).
tf32_hc_prenorm_gemm(x, fn, out, sqrsum, num_split) ¶
Perform the following computation:
out = x.float() @ fn.T
sqrsum = x.float().square().sum(-1)
See the calling function for the shape requirements.
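For checking results by hand, the two formulas can be written as a pure-Python reference over nested lists. This sketch assumes x is [M, K] and fn is [N, K], ignores num_split and TF32 rounding, and returns new lists instead of writing into out and sqrsum. `prenorm_gemm_reference` is a hypothetical name.

```python
def prenorm_gemm_reference(x, fn):
    """Return (x @ fn.T, per-row sum of squares of x) for nested-list inputs."""
    out = [[sum(a * b for a, b in zip(row, w)) for w in fn] for row in x]
    sqrsum = [sum(a * a for a in row) for row in x]
    return out, sqrsum


x = [[1.0, 2.0], [3.0, 4.0]]                 # M = 2, K = 2
fn = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]    # N = 3, K = 2
out, sqrsum = prenorm_gemm_reference(x, fn)
```

The row-wise sum of squares is what a later normalization step needs, so computing it in the same pass as the matrix product avoids reading x a second time.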