vllm.model_executor.layers.fla.ops.solve_tril
Functions:
- solve_tril – Compute the inverse of the matrix I + A.
solve_tril(A, cu_seqlens=None, chunk_indices=None, output_dtype=torch.float)
Compute the inverse of the matrix I + A. A must be strictly lower triangular within each BT × BT chunk block, i.e., A.triu() == 0.
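For a single BT × BT block, the inverse can be built by forward substitution. Because I + A is unit lower triangular, its inverse M is unit lower triangular too, and for i > j the entries satisfy M[i][j] = -sum over k = j..i-1 of A[i][k] * M[k][j]. The pure-Python sketch below illustrates only that recurrence on one small block; it is not the GPU kernel behind solve_tril, and `inv_unit_lower` and `matmul` are hypothetical helper names.

```python
from typing import List

Matrix = List[List[float]]


def inv_unit_lower(a: Matrix) -> Matrix:
    """Return (I + A)^-1 for a strictly lower triangular square matrix A.

    Hypothetical helper: entries of A on or above the diagonal are ignored.
    """
    n = len(a)
    m = [[0.0] * n for _ in range(n)]
    for j in range(n):
        m[j][j] = 1.0  # I + A has a unit diagonal, so its inverse does too
        for i in range(j + 1, n):
            # Only k in [j, i) contributes: A[i][k] needs k < i, M[k][j] needs k >= j.
            acc = 0.0
            for k in range(j, i):
                acc += a[i][k] * m[k][j]
            m[i][j] = -acc
    return m


def matmul(x: Matrix, y: Matrix) -> Matrix:
    """Plain dense matrix product, used only to check the result."""
    return [
        [sum(x[i][k] * y[k][j] for k in range(len(y))) for j in range(len(y[0]))]
        for i in range(len(x))
    ]


a = [
    [0.0, 0.0, 0.0],
    [2.0, 0.0, 0.0],
    [1.0, 3.0, 0.0],
]
m = inv_unit_lower(a)
print(m)  # -> [[1.0, 0.0, 0.0], [-2.0, 1.0, 0.0], [5.0, -3.0, 1.0]]

eye_plus_a = [[a[i][j] + (1.0 if i == j else 0.0) for j in range(3)] for i in range(3)]
print(matmul(eye_plus_a, m))  # -> [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
```

The recurrence touches each column once and never forms a dense inverse, which is why triangular solves are preferred over a general inverse for this structure.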
Parameters:
- A (Tensor) – Shape [B, T, H, BT], where BT (the chunk size) must be 16, 32, or 64.
- cu_seqlens (Tensor, default: None) – The cumulative sequence lengths of the input tensor, for variable-length inputs.
- chunk_indices (Tensor, default: None) – Pre-computed chunk indices.
- output_dtype (dtype, default: torch.float) – The dtype of the output tensor. If None, the output dtype is the same as the input dtype.
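To make the [B, T, H, BT] layout concrete, here is a slow NumPy reference that inverts each chunk block independently. It rests on assumptions not stated in this reference: each run of BT rows along T for one head holds a single BT × BT block; when cu_seqlens is given, B is 1 and chunks restart at every sequence boundary, with the last chunk of a sequence possibly shorter than BT; and np.float32 stands in for torch.float. `solve_tril_reference` is a hypothetical name, chunk_indices is omitted because the sketch derives chunks directly from cu_seqlens, and np.linalg.inv is used for clarity rather than speed.

```python
from typing import Optional, Sequence

import numpy as np
import numpy.typing as npt


def solve_tril_reference(
    A: np.ndarray,
    cu_seqlens: Optional[Sequence[int]] = None,
    output_dtype: npt.DTypeLike = np.float32,  # assumed stand-in for torch.float
) -> np.ndarray:
    """Hypothetical slow reference for chunk-wise (I + A)^-1 on a [B, T, H, BT] array."""
    B, T, H, BT = A.shape
    if BT not in (16, 32, 64):
        raise ValueError(f"BT must be 16, 32, or 64, got {BT}")
    # output_dtype=None keeps the input dtype, mirroring the documented behavior.
    out_dtype = A.dtype if output_dtype is None else np.dtype(output_dtype)
    out = np.zeros(A.shape, dtype=out_dtype)

    # Assumption: with cu_seqlens, sequences are packed along T in a batch of 1.
    if cu_seqlens is None:
        spans = [(b, 0, T) for b in range(B)]
    else:
        if B != 1:
            raise ValueError("this sketch expects B == 1 when cu_seqlens is given")
        bounds = [int(x) for x in cu_seqlens]
        spans = [(0, bos, eos) for bos, eos in zip(bounds[:-1], bounds[1:])]

    for b, bos, eos in spans:
        # Chunks restart at each sequence boundary; the last one may be short.
        for start in range(bos, eos, BT):
            n = min(BT, eos - start)
            for h in range(H):
                # Treat entries on or above the diagonal as zero (this sketch's choice).
                blk = np.tril(A[b, start:start + n, h, :n].astype(np.float64), k=-1)
                inv = np.linalg.inv(np.eye(n) + blk)
                out[b, start:start + n, h, :n] = inv  # cast to out_dtype on assignment
    return out
```

A simple correctness check is that (I + blk) @ out_blk is close to the identity for every chunk block. A dense inverse costs O(BT^3) per block, acceptable for a test oracle but not for production use.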
Returns:
- Tensor – (I + A)^-1, with the same shape as A.
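Why the result fits in the same shape as A can be seen per chunk. Writing $A_c$ for one BT × BT block (notation introduced here, not part of the API), strict lower triangularity makes $A_c$ nilpotent, so the Neumann series for the inverse terminates after BT terms:

$$
(I + A_c)^{-1} = \sum_{k=0}^{BT-1} (-A_c)^k = I - A_c + A_c^2 - \cdots + (-A_c)^{BT-1}, \qquad A_c^{BT} = 0.
$$

Every term is lower triangular, so $(I + A_c)^{-1}$ is unit lower triangular and can be stored in the same BT × BT slot that held $A_c$.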