Skip to content

vllm.lora.layers

Modules:

Classes:

BaseLayerWithLoRA

Bases: Module

Methods:

Source code in vllm/lora/layers/base.py
class BaseLayerWithLoRA(nn.Module):
    """Interface for layers that wrap a base layer with LoRA support.

    The default weight-management hooks below have empty (``...``) bodies,
    so they do nothing and return ``None``; concrete LoRA layers override
    them. ``can_replace_layer`` raises ``NotImplementedError`` unless a
    subclass overrides it.
    """

    @overload
    def slice_lora_a(
        self, lora_a: list[torch.Tensor | None]
    ) -> list[torch.Tensor | None]: ...
    @overload
    def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor: ...
    def slice_lora_a(
        self, lora_a: torch.Tensor | list[torch.Tensor | None]
    ) -> torch.Tensor | list[torch.Tensor | None]:
        """Slice lora a if splitting for tensor parallelism.

        Accepts a single tensor or a list of optional tensors and, per the
        overloads, returns the same kind of value. The base implementation is
        empty (returns ``None``); subclasses override it.
        """
        ...

    @overload
    def slice_lora_b(
        self, lora_b: list[torch.Tensor | None]
    ) -> list[torch.Tensor | None]: ...
    @overload
    def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor: ...
    def slice_lora_b(
        self, lora_b: torch.Tensor | list[torch.Tensor | None]
    ) -> torch.Tensor | list[torch.Tensor | None]:
        """Slice lora b if splitting with tensor parallelism.

        Accepts a single tensor or a list of optional tensors and, per the
        overloads, returns the same kind of value. The base implementation is
        empty (returns ``None``); subclasses override it.
        """
        ...

    def create_lora_weights(
        self,
        max_loras: int,
        lora_config: LoRAConfig,
        model_config: PretrainedConfig | None = None,
    ) -> None:
        """Initializes lora matrices.

        Args:
            max_loras: number of LoRA adapter slots to allocate.
            lora_config: LoRA settings for the buffers.
            model_config: optional HF model config, for subclasses that need it.

        The base implementation is empty; subclasses override it.
        """
        ...

    def reset_lora(self, index: int):
        """Resets the lora weights at index back to 0.

        The base implementation is empty; subclasses override it.
        """
        ...

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor | list[torch.Tensor],
        lora_b: torch.Tensor | list[torch.Tensor],
    ):
        """Overwrites lora tensors at index.

        Args:
            index: adapter slot to write.
            lora_a: LoRA A weight(s) for that slot.
            lora_b: LoRA B weight(s) for that slot.

        The base implementation is empty; subclasses override it.
        """
        ...

    def set_mapping(
        self,
        punica_wrapper,
    ):
        """Store the Punica wrapper on the layer as ``self.punica_wrapper``."""
        self.punica_wrapper: PunicaWrapperBase = punica_wrapper

    @classmethod
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: list,
        model_config: PretrainedConfig | None = None,
    ) -> bool:
        """Returns True if the layer can be replaced by this LoRA layer.

        The base implementation always raises ``NotImplementedError``;
        subclasses override it with their own matching rule.
        """
        raise NotImplementedError

can_replace_layer(source_layer, lora_config, packed_modules_list, model_config=None) classmethod

Returns True if the layer can be replaced by this LoRA layer.

Source code in vllm/lora/layers/base.py
@classmethod
def can_replace_layer(
    cls,
    source_layer: nn.Module,
    lora_config: LoRAConfig,
    packed_modules_list: list,
    model_config: PretrainedConfig | None = None,
) -> bool:
    """Returns True if the layer can be replaced by this LoRA layer.

    The base implementation always raises ``NotImplementedError``;
    subclasses override it with their own matching rule.
    """
    raise NotImplementedError

create_lora_weights(max_loras, lora_config, model_config=None)

Initializes lora matrices.

Source code in vllm/lora/layers/base.py
def create_lora_weights(
    self,
    max_loras: int,
    lora_config: LoRAConfig,
    model_config: PretrainedConfig | None = None,
) -> None:
    """Initializes lora matrices.

    Args:
        max_loras: number of LoRA adapter slots to allocate.
        lora_config: LoRA settings for the buffers.
        model_config: optional HF model config, for subclasses that need it.

    The base implementation is empty; subclasses override it.
    """
    ...

reset_lora(index)

Resets the lora weights at index back to 0.

Source code in vllm/lora/layers/base.py
def reset_lora(self, index: int):
    """Resets the lora weights at index back to 0.

    The base implementation is empty; subclasses override it.
    """
    ...

set_lora(index, lora_a, lora_b)

Overwrites lora tensors at index.

Source code in vllm/lora/layers/base.py
def set_lora(
    self,
    index: int,
    lora_a: torch.Tensor | list[torch.Tensor],
    lora_b: torch.Tensor | list[torch.Tensor],
):
    """Overwrites lora tensors at index.

    Args:
        index: adapter slot to write.
        lora_a: LoRA A weight(s) for that slot.
        lora_b: LoRA B weight(s) for that slot.

    The base implementation is empty; subclasses override it.
    """
    ...

slice_lora_a(lora_a)

slice_lora_a(
    lora_a: list[torch.Tensor | None],
) -> list[torch.Tensor | None]
slice_lora_a(lora_a: torch.Tensor) -> torch.Tensor

Slice lora a if splitting for tensor parallelism.

Source code in vllm/lora/layers/base.py
def slice_lora_a(
    self, lora_a: torch.Tensor | list[torch.Tensor | None]
) -> torch.Tensor | list[torch.Tensor | None]:
    """Slice lora a if splitting for tensor parallelism.

    Accepts a single tensor or a list of optional tensors. The base
    implementation is empty (returns ``None``); subclasses override it.
    """
    ...

slice_lora_b(lora_b)

slice_lora_b(
    lora_b: list[torch.Tensor | None],
) -> list[torch.Tensor | None]
slice_lora_b(lora_b: torch.Tensor) -> torch.Tensor

Slice lora b if splitting with tensor parallelism.

Source code in vllm/lora/layers/base.py
def slice_lora_b(
    self, lora_b: torch.Tensor | list[torch.Tensor | None]
) -> torch.Tensor | list[torch.Tensor | None]:
    """Slice lora b if splitting with tensor parallelism.

    Accepts a single tensor or a list of optional tensors. The base
    implementation is empty (returns ``None``); subclasses override it.
    """
    ...

ColumnParallelLinearWithLoRA

Bases: BaseLinearLayerWithLoRA

LoRA on top of a ColumnParallelLinear layer. LoRA B is sliced for tensor parallelism. Two kinds of base_layer are supported: 1. ColumnParallelLinear, e.g. dense_h_to_4h in FalconForCausalLM. 2. MergedColumnParallelLinear, e.g. gate_up_proj in Phi3ForCausalLM.

Methods:

  • forward

    Forward of ColumnParallelLinear

Source code in vllm/lora/layers/column_parallel_linear.py
class ColumnParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
    """
    LoRA on top of ColumnParallelLinear layer.
    LoRA B is sliced for tensor parallelism.
    There are two types for the `base_layer`:
    1. ColumnParallelLinear, e.g. `dense_h_to_4h` in `FalconForCausalLM`.
    2. MergedColumnParallelLinear, e.g. `gate_up_proj` in `Phi3ForCausalLM`.
    """

    def __init__(self, base_layer: ColumnParallelLinear) -> None:
        super().__init__(base_layer)
        # ColumnParallelLinear and MergedColumnParallelLinear shard their
        # weights differently once TP > 1, so record which one is wrapped.
        self.is_merged_col_linear = isinstance(base_layer, MergedColumnParallelLinear)
        # Output width owned by this TP rank.
        self.output_size = self.base_layer.output_size_per_partition
        # A single LoRA slice covers the whole output.
        self.n_slices = 1

    def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor:
        """LoRA A is not sliced for this layer; return it unchanged."""
        return lora_a

    def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
        """Return this TP rank's rows of LoRA B.

        For a plain ColumnParallelLinear the rank's contiguous block of
        `output_size` rows is taken. For MergedColumnParallelLinear the rows
        hold two equal halves; this rank's shard is taken from each half and
        the two shards are concatenated.
        """
        rank = self.tp_rank
        if not self.is_merged_col_linear:
            width = self.output_size
            return lora_b[rank * width : (rank + 1) * width, :]

        half_width = self.output_size // 2
        second_half_start = lora_b.shape[0] // 2
        lo = rank * half_width
        hi = lo + half_width
        first_part = lora_b[lo:hi, :]
        second_part = lora_b[second_half_start + lo : second_half_start + hi, :]
        return torch.cat([first_part, second_part], dim=0)

    def forward(
        self, input_: torch.Tensor
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor | None]:
        """Forward of ColumnParallelLinear

        Args:
            input_: Tensor whose last dimension is `input_size`.

        Returns:
            - output
            - bias
        """
        layer = self.base_layer
        # The bias is applied inside the matmul unless skip_bias_add is set.
        fused_bias = None if layer.skip_bias_add else layer.bias

        out = self.apply(input_, fused_bias)
        if layer.gather_output and self.tp_size > 1:
            # Reassemble the full output from every partition.
            out = tensor_model_parallel_all_gather(out)

        if not layer.return_bias:
            return out
        return out, (layer.bias if layer.skip_bias_add else None)

    @classmethod
    @_not_fully_sharded_can_replace
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: list,
        model_config: PretrainedConfig | None = None,
    ) -> bool:
        """Match an exact ColumnParallelLinear, or a MergedColumnParallelLinear
        with one packed module and fewer than three output slices."""
        # Exact type check: subclasses of ColumnParallelLinear do not match here.
        if type(source_layer) is maybe_get_oot_by_class(ColumnParallelLinear):
            return True
        merged_cls = maybe_get_oot_by_class(MergedColumnParallelLinear)
        if not isinstance(source_layer, merged_cls) or len(packed_modules_list) != 1:
            return False
        # Layers with 3+ output sizes are handled by
        # MergedColumnParallelLinearVariableSliceWithLoRA, because
        # slice_lora_b above assumes exactly 2 slices.
        has_three_or_more = (
            hasattr(source_layer, "output_sizes")
            and len(source_layer.output_sizes) >= 3
        )
        return not has_three_or_more

forward(input_)

Forward of ColumnParallelLinear

Parameters:

  • input_

    (Tensor) –

    Tensor whose last dimension is input_size.

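Returns:

  • output – the layer output, all-gathered across partitions when gather_output is set and tp_size > 1. Returned alone when return_bias is False.
  • bias – returned alongside output only when return_bias is True: the base layer's bias if skip_bias_add is set, otherwise None.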

Source code in vllm/lora/layers/column_parallel_linear.py
def forward(
    self, input_: torch.Tensor
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor | None]:
    """Forward of ColumnParallelLinear

    Args:
        input_: Tensor whose last dimension is `input_size`.

    Returns:
        - output
        - bias
    """
    bias = self.base_layer.bias if not self.base_layer.skip_bias_add else None

    # Matrix multiply.
    output_parallel = self.apply(input_, bias)
    if self.base_layer.gather_output and self.tp_size > 1:
        # All-gather across the partitions.
        output = tensor_model_parallel_all_gather(output_parallel)
    else:
        output = output_parallel

    if not self.base_layer.return_bias:
        return output

    output_bias = self.base_layer.bias if self.base_layer.skip_bias_add else None
    return output, output_bias

ColumnParallelLinearWithShardedLoRA

Bases: ColumnParallelLinearWithLoRA

Differs from ColumnParallelLinearWithLoRA by slicing LoRA A also.

Based on S-LoRA, slicing happens along the rank dim.

Source code in vllm/lora/layers/column_parallel_linear.py
class ColumnParallelLinearWithShardedLoRA(ColumnParallelLinearWithLoRA):
    """
    Differs from ColumnParallelLinearWithLoRA by slicing LoRA A also.

    Based on S-LoRA, slicing happens along the rank dim.
    """

    def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor:
        """Keep only this TP rank's rows of LoRA A (its share of the rank dim).

        When `base_layer` is `ColumnParallelLinear`, `lora_a` and `lora_b` are
        sharded differently. The result of the `lora_a` GEMM is gathered
        afterwards, so the `lora_a` shard only has to line up with that
        gather. The shard height comes from dim 2 of `lora_a_stacked[0]`.
        """
        rows_per_rank = self.lora_a_stacked[0].shape[2]
        first_row = rows_per_rank * self.tp_rank
        return lora_a[first_row : first_row + rows_per_rank, :]

    def apply(self, x: torch.Tensor, bias: torch.Tensor | None = None) -> torch.Tensor:
        """Run the base matmul plus sharded LoRA via `_mcp_apply`."""
        return _mcp_apply(x, bias, self)

    @classmethod
    @_fully_sharded_can_replace
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: list,
        model_config: PretrainedConfig | None = None,
    ) -> bool:
        """Reuse the parent's matching rule with its decorator disabled."""
        # Arguments are passed by keyword so the decorator can look them up.
        return super().can_replace_layer(
            source_layer=source_layer,
            lora_config=lora_config,
            packed_modules_list=packed_modules_list,
            model_config=model_config,
            decorate=False,
        )

FusedMoE3DWithLoRA

Bases: FusedMoEWithLoRA

Methods:

Attributes:

Source code in vllm/lora/layers/fused_moe.py
class FusedMoE3DWithLoRA(FusedMoEWithLoRA):
    """Fused-MoE LoRA layer that keeps the w13 (w1 + w3) LoRA weights as one
    combined slice.

    Unlike ``FusedMoEWithLoRA``, ``_w13_slices`` is 1. The w13 LoRA B buffer
    spans ``2 * intermediate_size_per_partition`` rows holding both w1 and w3.
    The two halves are back to back, except for ``GptOssForCausalLM`` where
    they are interleaved row by row (see ``_slice_w13_b``).
    """

    def __init__(self, base_layer: MoERunner):
        super().__init__(base_layer)
        # One combined w13 slice instead of separate w1 / w3 slices.
        self._w13_slices = 1

    def _create_lora_b_weights(self, max_loras: int, lora_config: LoRAConfig):
        """Allocate zero-initialised LoRA B buffers for w13 and w2.

        w13 B: (max_loras, local_num_experts,
                2 * intermediate_size_per_partition, max_lora_rank).
        w2 B:  (max_loras, local_num_experts, hidden_size, max_lora_rank),
               with hidden_size divided by tp_size when LoRA is fully sharded.
        """
        self.w13_lora_b_stacked: tuple[torch.Tensor, ...] = tuple(
            torch.zeros(
                (
                    max_loras,
                    self.local_num_experts,
                    self.intermediate_size_per_partition * 2,
                    lora_config.max_lora_rank,
                ),
                dtype=lora_config.lora_dtype,
                device=self.device,
            )
            for _ in range(self._w13_slices)
        )
        self.w2_lora_b_stacked: tuple[torch.Tensor, ...] = (
            torch.zeros(
                (
                    max_loras,
                    self.local_num_experts,
                    self.hidden_size
                    if not self.fully_sharded
                    else divide(self.hidden_size, self.tp_size),
                    lora_config.max_lora_rank,
                ),
                dtype=lora_config.lora_dtype,
                device=self.device,
            ),
        )

    def create_lora_weights(
        self,
        max_loras: int,
        lora_config: LoRAConfig,
        model_config: PretrainedConfig | None = None,
    ) -> None:
        """Initializes lora matrices.

        ``model_config`` must be a ``PretrainedConfig``. Its first
        ``architectures`` entry is stored as ``_base_model`` and selects the
        w13 LoRA B layout used by ``_slice_w13_b``.
        """

        assert isinstance(model_config, PretrainedConfig)
        self._verify_ep_fs(lora_config)
        # `architectures` is optional on a PretrainedConfig (it may be None or
        # empty), so indexing it directly could raise. Without an architecture
        # name the GPT-OSS interleaved layout cannot be selected; use None.
        architectures = model_config.architectures or []
        self._base_model = architectures[0] if architectures else None
        self.max_loras = lora_config.max_loras
        self.fully_sharded = lora_config.fully_sharded_loras

        # One enable flag per adapter slot, plus one extra entry.
        self.adapter_enabled = torch.tensor(
            [0] * (max_loras + 1), dtype=torch.int, device=self.device
        )

        self._create_lora_a_weights(max_loras, lora_config)
        self._create_lora_b_weights(max_loras, lora_config)

    def _slice_w13_b(self, w13_lora_b: torch.Tensor):
        """Take this TP rank's share of the combined w13 LoRA B.

        ``w13_lora_b`` is indexed as (num_experts, output_size, rank). The w1
        and w3 halves are split out, this rank's
        ``intermediate_size_per_partition`` rows are taken from each, and the
        halves are recombined in the same layout they arrived in. With a
        single TP rank the tensor is returned unchanged.
        """
        if self.tp_size == 1:
            return w13_lora_b

        # w13_lora_b shape (num_experts,output_size,rank)
        shard_size = self.intermediate_size_per_partition
        start_idx = self.tp_rank * shard_size
        end_idx = (self.tp_rank + 1) * shard_size
        # HACK: Currently, only GPT-OSS is in interleaved order
        if self._base_model == "GptOssForCausalLM":
            # GPT-OSS stores w1 (gate_proj) and w3 (up_proj) interleaved: even
            # rows are w1, odd rows are w3. Slice each, then re-interleave.
            w1_lora_b = w13_lora_b[:, ::2, :]
            w3_lora_b = w13_lora_b[:, 1::2, :]
            sliced_w1_lora_b = w1_lora_b[:, start_idx:end_idx, :]
            sliced_w3_lora_b = w3_lora_b[:, start_idx:end_idx, :]

            return torch.stack([sliced_w1_lora_b, sliced_w3_lora_b], dim=2).flatten(
                1, 2
            )
        else:
            # Back-to-back layout: first half w1, second half w3.
            slice_size = w13_lora_b.shape[1] // 2
            w1_lora_b = w13_lora_b[:, :slice_size, :]
            w3_lora_b = w13_lora_b[:, slice_size:, :]
            sliced_w1_lora_b = w1_lora_b[:, start_idx:end_idx, :]
            sliced_w3_lora_b = w3_lora_b[:, start_idx:end_idx, :]

            return torch.cat([sliced_w1_lora_b, sliced_w3_lora_b], dim=1)

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor | list[torch.Tensor],
        lora_b: torch.Tensor | list[torch.Tensor],
    ):
        """Overwrites lora tensors at index.

        ``lora_a`` and ``lora_b`` must each be a 2-element list
        ``[w13, w2]``. Each weight is sliced for this TP rank and copied into
        the leading corner of its stacked buffer at ``index``; the slot is
        zeroed first via ``reset_lora``.
        """
        # Make mypy happy
        assert isinstance(lora_a, list)
        assert isinstance(lora_b, list)
        assert len(lora_a) == len(lora_b) == 2

        self.reset_lora(index)
        self.adapter_enabled[index] = 1

        w13_lora_a, w2_lora_a = lora_a
        w13_lora_b, w2_lora_b = lora_b

        sliced_w13_lora_a = self._slice_w13_a(w13_lora_a)
        sliced_w13_lora_b = self._slice_w13_b(w13_lora_b)

        sliced_w2_lora_a = self._slice_w2_a(w2_lora_a)
        sliced_w2_lora_b = self._slice_w2_b(w2_lora_b)

        self.w13_lora_a_stacked[0][
            index, :, : sliced_w13_lora_a.shape[1], : sliced_w13_lora_a.shape[2]
        ].copy_(sliced_w13_lora_a, non_blocking=True)
        self.w2_lora_a_stacked[0][
            index, :, : sliced_w2_lora_a.shape[1], : sliced_w2_lora_a.shape[2]
        ].copy_(sliced_w2_lora_a, non_blocking=True)

        self.w13_lora_b_stacked[0][
            index, :, : sliced_w13_lora_b.shape[1], : sliced_w13_lora_b.shape[2]
        ].copy_(sliced_w13_lora_b, non_blocking=True)
        self.w2_lora_b_stacked[0][
            index, :, : sliced_w2_lora_b.shape[1], : sliced_w2_lora_b.shape[2]
        ].copy_(sliced_w2_lora_b, non_blocking=True)

    @property
    def w13_input_size(self):
        """
        Full (unsharded) input size of w13: the last dim of the w13 LoRA A
        buffer.
        """
        return self.w13_lora_a_stacked[0].shape[-1]

    @property
    def w13_output_size(self):
        """
        Full output size of w13 across all TP ranks: the per-rank w13 LoRA B
        rows multiplied by tp_size.
        """
        return self.w13_lora_b_stacked[0].shape[-2] * self.tp_size

    @property
    def w2_input_size(self):
        """
        Full input size of w2 across all TP ranks: the per-rank w2 LoRA A
        width multiplied by tp_size.
        """
        return self.w2_lora_a_stacked[0].shape[-1] * self.tp_size

    @property
    def w2_output_size(self):
        """
        Full output size of w2, i.e. ``hidden_size``.
        """
        return self.hidden_size

    @classmethod
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: list,
        model_config: PretrainedConfig | None = None,
    ) -> bool:
        """Returns True if the layer can be replaced by this LoRA layer."""
        # source_layer is MoERunner
        moe_cls = maybe_get_oot_by_class(MoERunner)
        return isinstance(source_layer, moe_cls) and len(packed_modules_list) == 1

w13_input_size property

Full (unsharded) input size of the w13 projection.

w13_output_size property

Full output size of the w13 projection across all tensor-parallel ranks.

w2_input_size property

Full input size of the w2 projection across all tensor-parallel ranks.

w2_output_size property

Full output size of the w2 projection (the hidden size).

can_replace_layer(source_layer, lora_config, packed_modules_list, model_config=None) classmethod

Returns True if the layer can be replaced by this LoRA layer.

Source code in vllm/lora/layers/fused_moe.py
@classmethod
def can_replace_layer(
    cls,
    source_layer: nn.Module,
    lora_config: LoRAConfig,
    packed_modules_list: list,
    model_config: PretrainedConfig | None = None,
) -> bool:
    """Returns True if the layer can be replaced by this LoRA layer.

    The layer matches when it is an instance of the class returned by
    ``maybe_get_oot_by_class(MoERunner)`` and exactly one packed module is
    listed.
    """
    if not isinstance(source_layer, maybe_get_oot_by_class(MoERunner)):
        return False
    return len(packed_modules_list) == 1

create_lora_weights(max_loras, lora_config, model_config=None)

Initializes lora matrices.

Source code in vllm/lora/layers/fused_moe.py
def create_lora_weights(
    self,
    max_loras: int,
    lora_config: LoRAConfig,
    model_config: PretrainedConfig | None = None,
) -> None:
    """Initializes lora matrices.

    ``model_config`` must be a ``PretrainedConfig``. Its first
    ``architectures`` entry is stored as ``_base_model`` and later selects the
    w13 LoRA B slicing layout.
    """

    assert isinstance(model_config, PretrainedConfig)
    self._verify_ep_fs(lora_config)
    # `architectures` is optional on a PretrainedConfig (it may be None or
    # empty), so indexing it directly could raise. Without an architecture
    # name no architecture-specific layout can be selected; use None.
    architectures = model_config.architectures or []
    self._base_model = architectures[0] if architectures else None
    self.max_loras = lora_config.max_loras
    self.fully_sharded = lora_config.fully_sharded_loras

    # One enable flag per adapter slot, plus one extra entry.
    self.adapter_enabled = torch.tensor(
        [0] * (max_loras + 1), dtype=torch.int, device=self.device
    )

    self._create_lora_a_weights(max_loras, lora_config)
    self._create_lora_b_weights(max_loras, lora_config)

set_lora(index, lora_a, lora_b)

Overwrites lora tensors at index.

Source code in vllm/lora/layers/fused_moe.py
def set_lora(
    self,
    index: int,
    lora_a: torch.Tensor | list[torch.Tensor],
    lora_b: torch.Tensor | list[torch.Tensor],
):
    """Overwrites lora tensors at index."""
    # Make mypy happy
    assert isinstance(lora_a, list)
    assert isinstance(lora_b, list)
    assert len(lora_a) == len(lora_b) == 2

    self.reset_lora(index)
    self.adapter_enabled[index] = 1

    w13_lora_a, w2_lora_a = lora_a
    w13_lora_b, w2_lora_b = lora_b

    sliced_w13_lora_a = self._slice_w13_a(w13_lora_a)
    sliced_w13_lora_b = self._slice_w13_b(w13_lora_b)

    sliced_w2_lora_a = self._slice_w2_a(w2_lora_a)
    sliced_w2_lora_b = self._slice_w2_b(w2_lora_b)

    self.w13_lora_a_stacked[0][
        index, :, : sliced_w13_lora_a.shape[1], : sliced_w13_lora_a.shape[2]
    ].copy_(sliced_w13_lora_a, non_blocking=True)
    self.w2_lora_a_stacked[0][
        index, :, : sliced_w2_lora_a.shape[1], : sliced_w2_lora_a.shape[2]
    ].copy_(sliced_w2_lora_a, non_blocking=True)

    self.w13_lora_b_stacked[0][
        index, :, : sliced_w13_lora_b.shape[1], : sliced_w13_lora_b.shape[2]
    ].copy_(sliced_w13_lora_b, non_blocking=True)
    self.w2_lora_b_stacked[0][
        index, :, : sliced_w2_lora_b.shape[1], : sliced_w2_lora_b.shape[2]
    ].copy_(sliced_w2_lora_b, non_blocking=True)

FusedMoEWithLoRA

Bases: BaseLayerWithLoRA

Methods:

Source code in vllm/lora/layers/fused_moe.py
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
class FusedMoEWithLoRA(BaseLayerWithLoRA):
    def __init__(self, base_layer: MoERunner) -> None:
        """Wrap an ``MoERunner`` so its fused experts can apply LoRA.

        Records the MoE config, the MoE-aware TP rank/size and the device, and
        sets up the optional auxiliary LoRA stream. It then replaces the base
        layer's quant method with a ``FusedMoEModularMethod`` built around a
        LoRA-capable MoE kernel.

        Raises:
            AssertionError: if the routed experts' quant method is monolithic,
                if the EP configuration is unsupported (``_ep_check``), or if
                the selected MoE kernel does not support LoRA.
        """
        super().__init__()
        self.base_layer = base_layer
        self.moe_config = base_layer.moe_config
        self._shared_experts = base_layer._shared_experts
        # `use_ep` (read by _ep_check) depends on self.moe_config set above.
        self._ep_check()

        routed_experts = self.base_layer.routed_experts
        assert not routed_experts.quant_method.is_monolithic, (
            "Monolithic kernels are not supported for Fused MoE LoRA."
        )

        # Use the MoE-aware TP rank/size: when EP is active, FusedMoE collapses
        # moe_parallel_config.tp_size to 1 (experts are sharded across the
        # TP group instead).
        moe_parallel_config = self.moe_config.moe_parallel_config
        self.tp_size = moe_parallel_config.tp_size
        self.tp_rank = moe_parallel_config.tp_rank
        self.device = _get_lora_device(base_layer)

        self._enable_aux_cuda_stream = envs.VLLM_LORA_ENABLE_DUAL_STREAM
        self._init_lora_stream_context()
        # For non-gated MoE (is_act_and_mul=False), only 1 slice is needed
        # since there's only up_proj (w1), not gate_proj + up_proj (w1 + w3)
        self._w13_slices = 2 if base_layer.moe_config.is_act_and_mul else 1
        # Mirrors per-(lora_id) layout of `self.lora_a_stacked` (built in
        # `create_lora_weights`) so `create_dummy_lora`'s n_slices fallback
        # matches `lora_a_stacked` length under EP.
        self.n_slices = self.local_num_experts * (self._w13_slices + 1)

        routed_experts._ensure_moe_quant_config_init()
        # Reuse the quant method's own modular kernel when it advertises one.
        # Otherwise build a FusedMoEKernel from a no-DP/EP prepare/finalize
        # step and the quant method's GEMM implementation.
        if getattr(routed_experts.quant_method, "supports_internal_mk", False):
            moe_kernel = routed_experts.quant_method.moe_kernel
        else:
            prepare_finalize = MoEPrepareAndFinalizeNoDPEPModular()
            moe_kernel = FusedMoEKernel(
                prepare_finalize,
                routed_experts.quant_method.select_gemm_impl(
                    prepare_finalize, routed_experts
                ),
            )
        assert moe_kernel.supports_lora(), (
            f"{type(moe_kernel.fused_experts).__name__} does not support LoRA. "
            "For unquantized MoE, set moe_backend='triton' or moe_backend='auto' "
            "(auto selects Triton automatically when LoRA is enabled). "
            "For quantized MoE, mix LoRAExpertsMixin into the experts class "
            "and consume self._lora_context in apply()."
        )
        self._moe_kernel = moe_kernel
        # Swap in a modular method that wraps the original quant method with
        # the LoRA-capable kernel.
        self.base_layer._replace_quant_method(
            FusedMoEModularMethod(self.base_layer._quant_method, moe_kernel)
        )

    @property
    def hidden_size(self) -> int:
        """Hidden dimension of the MoE layer (``moe_config.hidden_dim``)."""
        cfg = self.moe_config
        return cfg.hidden_dim

    @property
    def local_num_experts(self) -> int:
        """Number of experts on this rank (``moe_config.num_local_experts``)."""
        cfg = self.moe_config
        return cfg.num_local_experts

    @property
    def global_num_experts(self) -> int:
        """Total number of experts (``moe_config.num_experts``)."""
        cfg = self.moe_config
        return cfg.num_experts

    @property
    def ep_rank(self) -> int:
        """Expert-parallel rank taken from the MoE parallel config."""
        parallel_cfg = self.moe_config.moe_parallel_config
        return parallel_cfg.ep_rank

    @property
    def use_ep(self) -> bool:
        """Whether expert parallelism is enabled in the MoE parallel config."""
        parallel_cfg = self.moe_config.moe_parallel_config
        return parallel_cfg.use_ep

    @property
    def intermediate_size_per_partition(self) -> int:
        """Per-partition intermediate size from ``moe_config``."""
        cfg = self.moe_config
        return cfg.intermediate_size_per_partition

    def _init_lora_stream_context(self) -> None:
        """Set up the optional auxiliary CUDA stream and its events for LoRA.

        Both stay ``None`` unless dual-stream LoRA is enabled and the platform
        is CUDA-like.
        """
        self._lora_stream: torch.cuda.Stream | None = None
        self._events: tuple[torch.cuda.Event, ...] | None = None
        if self._enable_aux_cuda_stream and current_platform.is_cuda_alike():
            self._lora_stream = _get_lora_aux_cuda_stream()
            # Four events, two per (base GEMM, LoRA) pair, so w13 and w2 never
            # share event objects. Reuse within a pair is fine because the
            # second pair starts only after intermediate_cache1.add_() joins.
            self._events = tuple(torch.cuda.Event() for _ in range(4))

    def _build_lora_context(self):
        """Bundle this layer's LoRA buffers and settings into a MoELoRAContext.

        The auxiliary stream and its events are passed only when dual-stream
        LoRA is enabled, LoRA is not fully sharded, and a stream exists;
        otherwise both are ``None``.
        """
        dual_stream_ok = (
            self._enable_aux_cuda_stream
            and not self.fully_sharded
            and self._lora_stream is not None
        )
        if dual_stream_ok:
            aux_stream, events = self._lora_stream, self._events
        else:
            aux_stream, events = None, None
        return MoELoRAContext(
            w13_lora_a_stacked=self.w13_lora_a_stacked,
            w13_lora_b_stacked=self.w13_lora_b_stacked,
            w2_lora_a_stacked=self.w2_lora_a_stacked,
            w2_lora_b_stacked=self.w2_lora_b_stacked,
            adapter_enabled=self.adapter_enabled,
            max_loras=self.max_loras,
            top_k=self.moe_config.experts_per_token,
            w13_num_slices=self._w13_slices,
            fully_sharded=self.fully_sharded,
            tp_rank=self.tp_rank,
            tp_size=self.tp_size,
            local_num_experts=self.local_num_experts,
            punica_wrapper=self.punica_wrapper,
            use_tuned_config=bool(envs.VLLM_TUNED_CONFIG_FOLDER),
            aux_stream=aux_stream,
            events=events,
        )

    def _create_lora_a_weights(
        self,
        max_loras: int,
        lora_config: LoRAConfig,
    ):
        """Allocate zero-initialised LoRA A buffers for w13 and w2.

        w13 A (one tensor per w13 slice):
            (max_loras, local_num_experts, rank, hidden_size), where rank is
            divided by tp_size when LoRA is fully sharded.
        w2 A (single tensor):
            (max_loras, local_num_experts, max_lora_rank,
             intermediate_size_per_partition).
        """
        full_rank = lora_config.max_lora_rank
        # Fully sharded LoRA splits the w13 A rank dim across TP ranks.
        w13_rank = divide(full_rank, self.tp_size) if self.fully_sharded else full_rank
        alloc = dict(dtype=lora_config.lora_dtype, device=self.device)

        w13_shape = (max_loras, self.local_num_experts, w13_rank, self.hidden_size)
        self.w13_lora_a_stacked: tuple[torch.Tensor, ...] = tuple(
            torch.zeros(w13_shape, **alloc) for _ in range(self._w13_slices)
        )

        w2_shape = (
            max_loras,
            self.local_num_experts,
            full_rank,
            self.intermediate_size_per_partition,
        )
        self.w2_lora_a_stacked: tuple[torch.Tensor, ...] = (
            torch.zeros(w2_shape, **alloc),
        )

    def _create_lora_b_weights(self, max_loras: int, lora_config: LoRAConfig):
        """Allocate zero-initialised LoRA B buffers for w13 and w2.

        w13 B (one tensor per w13 slice):
            (max_loras, local_num_experts, intermediate_size_per_partition,
             max_lora_rank).
        w2 B (single tensor):
            (max_loras, local_num_experts, hidden_size, max_lora_rank), where
            hidden_size is divided by tp_size when LoRA is fully sharded.
        """
        alloc = dict(dtype=lora_config.lora_dtype, device=self.device)

        w13_shape = (
            max_loras,
            self.local_num_experts,
            self.intermediate_size_per_partition,
            lora_config.max_lora_rank,
        )
        self.w13_lora_b_stacked: tuple[torch.Tensor, ...] = tuple(
            torch.zeros(w13_shape, **alloc) for _ in range(self._w13_slices)
        )

        # Fully sharded LoRA splits the w2 B output (hidden) dim across ranks.
        if self.fully_sharded:
            w2_rows = divide(self.hidden_size, self.tp_size)
        else:
            w2_rows = self.hidden_size
        w2_shape = (
            max_loras,
            self.local_num_experts,
            w2_rows,
            lora_config.max_lora_rank,
        )
        self.w2_lora_b_stacked: tuple[torch.Tensor, ...] = (
            torch.zeros(w2_shape, **alloc),
        )

    def _ep_check(self):
        """Reject expert-parallel setups that Fused MoE LoRA cannot handle.

        With EP enabled, the all2all backend must be
        ``allgather_reducescatter`` and sequence parallelism must be off.
        """
        if not self.use_ep:
            return
        parallel_cfg = self.moe_config.moe_parallel_config
        all2all_backend = parallel_cfg.all2all_backend
        assert all2all_backend == "allgather_reducescatter", (
            "Fused MoE LoRA with EP currently only supports "
            f"all2all_backend='allgather_reducescatter', got '{all2all_backend}'."
        )
        assert not parallel_cfg.is_sequence_parallel

    def _verify_ep_fs(self, lora_config: LoRAConfig):
        """Forbid combining expert parallelism with fully sharded LoRA.

        Both partition along the same TP group: EP over the expert dim, and
        fully sharded LoRA over the LoRA rank dim. They make contradictory
        assumptions about which rank holds which expert's rank-shard.
        """
        both_enabled = self.use_ep and lora_config.fully_sharded_loras
        assert not both_enabled, (
            "Fused MoE LoRA does not support enable_expert_parallel=True "
            "together with fully_sharded_loras=True. Disable one of them."
        )

    def create_lora_weights(
        self,
        max_loras: int,
        lora_config: LoRAConfig,
        model_config: PretrainedConfig | None = None,
    ) -> None:
        """Initializes lora matrices.

        Allocates the stacked w13/w2 LoRA A and B buffers and the per-slot
        enable flags. It also builds flat per-(lora, expert) lists of buffer
        views, ``lora_a_stacked`` and ``lora_b_stacked``.
        ``LoRALayerWeights.create_dummy_lora_weights`` uses these lists to
        create dummy LoRA weights.
        """
        self._verify_ep_fs(lora_config)
        self.max_loras = lora_config.max_loras
        self.fully_sharded = lora_config.fully_sharded_loras

        # One enable flag per adapter slot, plus one extra entry.
        self.adapter_enabled = torch.tensor(
            [0] * (max_loras + 1), dtype=torch.int, device=self.device
        )

        self._create_lora_a_weights(max_loras, lora_config)
        self._create_lora_b_weights(max_loras, lora_config)

        # TODO Optimize this section
        gated = self._w13_slices == 2
        a_views: list[torch.Tensor] = []
        b_views: list[torch.Tensor] = []
        for lora_id in range(max_loras):
            for expert_id in range(self.local_num_experts):
                # Per expert the order is w1, w2, then w3 (gated MoE only):
                # gated = gate_proj/down_proj/up_proj, non-gated = up_proj/down_proj.
                a_views.append(self.w13_lora_a_stacked[0][lora_id][expert_id])
                a_views.append(self.w2_lora_a_stacked[0][lora_id][expert_id])
                b_views.append(self.w13_lora_b_stacked[0][lora_id][expert_id])
                b_views.append(self.w2_lora_b_stacked[0][lora_id][expert_id])
                if gated:
                    a_views.append(self.w13_lora_a_stacked[1][lora_id][expert_id])
                    b_views.append(self.w13_lora_b_stacked[1][lora_id][expert_id])
        self.lora_a_stacked = a_views
        self.lora_b_stacked = b_views

    def _slice_w13_a(self, w13_lora_a: torch.Tensor) -> torch.Tensor:
        """
        Applies to FusedMoEWithLoRA and FusedMoE3DWithLoRA

        Under fully sharded LoRA with more than one TP rank, returns this
        rank's share of the rank dim (dim 1 of (num_experts, rank,
        input_size)), following S-LoRA. Otherwise returns the input unchanged.
        """
        if self.tp_size != 1 and self.fully_sharded:
            assert w13_lora_a.shape[1] % self.tp_size == 0
            rows_per_rank = self.w13_lora_a_stacked[0].shape[2]
            first = self.tp_rank * rows_per_rank
            return w13_lora_a[:, first : first + rows_per_rank, :]
        return w13_lora_a

    def _slice_w13_b(self, w13_lora_b: torch.Tensor):
        if self.tp_size == 1:
            return w13_lora_b

        # w13_lora_b shape (num_experts,output_size,rank)
        shard_size = self.intermediate_size_per_partition
        start_idx = self.tp_rank * shard_size
        end_idx = (self.tp_rank + 1) * shard_size

        return w13_lora_b[:, start_idx:end_idx, :]

    def _slice_w2_a(self, w2_lora_a: torch.Tensor) -> torch.Tensor:
        """
        Applies to FusedMoEWithLoRA and FusedMoE3DWithLoRA
        """
        if self.tp_size == 1:
            return w2_lora_a
        # w2_lora_a shape (num_experts,rank,input_size)
        shard_size = self.intermediate_size_per_partition
        start_idx = self.tp_rank * shard_size
        end_idx = (self.tp_rank + 1) * shard_size

        return w2_lora_a[:, :, start_idx:end_idx]

    def _slice_w2_b(self, w2_lora_b: torch.Tensor) -> torch.Tensor:
        """
        Applies to FusedMoEWithLoRA and FusedMoE3DWithLoRA
        """
        if self.tp_size == 1 or not self.fully_sharded:
            return w2_lora_b
        # Based on S-LoRA, we slice W2 B along the hidden_size dim.
        # w2_lora_b shape (num_experts,output_size,rank)
        shard_size = self.w2_lora_b_stacked[0].shape[2]
        start_idx = self.tp_rank * shard_size
        end_idx = (self.tp_rank + 1) * shard_size

        return w2_lora_b[:, start_idx:end_idx, :]

    def reset_lora(self, index: int):
        """Zero slot ``index`` in every W13/W2 LoRA buffer and disable it.

        Only the first ``_w13_slices`` W13 buffers are touched. The
        matching ``adapter_enabled`` flag is also cleared.
        """
        for slot in range(self._w13_slices):
            for buffers in (self.w13_lora_a_stacked, self.w13_lora_b_stacked):
                buffers[slot][index] = 0

        for buffers in (self.w2_lora_a_stacked, self.w2_lora_b_stacked):
            buffers[0][index] = 0

        self.adapter_enabled[index] = 0

    #

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor | list[torch.Tensor],
        lora_b: torch.Tensor | list[torch.Tensor],
    ):
        """Overwrites lora tensors at index."""
        # Make mypy happy
        assert isinstance(lora_a, list)
        assert isinstance(lora_b, list)

        self.reset_lora(index)
        self.adapter_enabled[index] = 1

        num_experts = self.w13_lora_a_stacked[0].shape[1]

        w1_lora_a, w2_lora_a, w3_lora_a = lora_a
        w1_lora_b, w2_lora_b, w3_lora_b = lora_b

        # EP slicing is done once at add time in
        # LoRAModelManager._slice_moe_lora_ep, so by here the cached
        # tensors already match the local-expert dim of the stacked buffers.
        assert (
            num_experts
            == w1_lora_a.shape[0]
            == w2_lora_a.shape[0]
            == w3_lora_a.shape[0]
        )

        slliced_w1_lora_a = self._slice_w13_a(w1_lora_a)
        slliced_w1_lora_b = self._slice_w13_b(w1_lora_b)

        sliced_w2_lora_a = self._slice_w2_a(w2_lora_a)
        sliced_w2_lora_b = self._slice_w2_b(w2_lora_b)

        self.w13_lora_a_stacked[0][
            index, :, : slliced_w1_lora_a.shape[1], : slliced_w1_lora_a.shape[2]
        ].copy_(slliced_w1_lora_a, non_blocking=True)

        self.w13_lora_b_stacked[0][
            index, :, : slliced_w1_lora_b.shape[1], : slliced_w1_lora_b.shape[2]
        ].copy_(slliced_w1_lora_b, non_blocking=True)

        # Only copy w3 (up_proj) for gated MoE (_w13_slices == 2)
        if self._w13_slices == 2:
            slliced_w3_lora_a = self._slice_w13_a(w3_lora_a)
            slliced_w3_lora_b = self._slice_w13_b(w3_lora_b)

            self.w13_lora_a_stacked[1][
                index, :, : slliced_w3_lora_a.shape[1], : slliced_w3_lora_a.shape[2]
            ].copy_(slliced_w3_lora_a, non_blocking=True)

            self.w13_lora_b_stacked[1][
                index, :, : slliced_w3_lora_b.shape[1], : slliced_w3_lora_b.shape[2]
            ].copy_(slliced_w3_lora_b, non_blocking=True)

        self.w2_lora_a_stacked[0][
            index, :, : sliced_w2_lora_a.shape[1], : sliced_w2_lora_a.shape[2]
        ].copy_(sliced_w2_lora_a, non_blocking=True)

        self.w2_lora_b_stacked[0][
            index, :, : sliced_w2_lora_b.shape[1], : sliced_w2_lora_b.shape[2]
        ].copy_(sliced_w2_lora_b, non_blocking=True)

    def set_mapping(self, punica_wrapper):
        """Install ``punica_wrapper`` and push a fresh LoRA context to the kernel.

        The context from ``_build_lora_context()`` is always handed to
        ``fused_experts``. It is also handed to ``prepare_finalize``, but
        only when that object exposes ``set_lora_context``.
        """
        super().set_mapping(punica_wrapper)
        ctx = self._build_lora_context()
        self._moe_kernel.fused_experts.set_lora_context(ctx)
        finalize = self._moe_kernel.prepare_finalize
        if not hasattr(finalize, "set_lora_context"):
            return
        finalize.set_lora_context(ctx)

    def forward(self, *args, **kwargs):
        """Delegate the call, arguments unchanged, to ``base_layer.forward``."""
        base_forward = self.base_layer.forward
        return base_forward(*args, **kwargs)

    @property
    def quant_method(self):
        """The wrapped base layer's ``_quant_method``."""
        base = self.base_layer
        return base._quant_method

    @property
    def runner(self) -> MoERunner:
        """The wrapped base layer itself, exposed as the MoE runner."""
        wrapped = self.base_layer
        return wrapped

    @property
    def is_internal_router(self) -> bool:
        """Forwarded unchanged from ``base_layer.is_internal_router``."""
        base = self.base_layer
        return base.is_internal_router

    @classmethod
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: list,
        model_config: PretrainedConfig | None = None,
    ) -> bool:
        """Returns True if the layer can be replaced by this LoRA layer.

        Matches MoERunner instances (resolved through
        ``maybe_get_oot_by_class``, so an out-of-tree override also counts)
        that have exactly two packed modules.
        """
        runner_cls = maybe_get_oot_by_class(MoERunner)
        if not isinstance(source_layer, runner_cls):
            return False
        return len(packed_modules_list) == 2

_slice_w13_a(w13_lora_a)

Applies to FusedMoEWithLoRA and FusedMoE3DWithLoRA

Source code in vllm/lora/layers/fused_moe.py
def _slice_w13_a(self, w13_lora_a: torch.Tensor) -> torch.Tensor:
    """
    Applies to FusedMoEWithLoRA and FusedMoE3DWithLoRA
    """
    if self.tp_size == 1 or not self.fully_sharded:
        return w13_lora_a

    # w13_lora_a shape (num_experts,rank,input_size)
    current_lora_rank = w13_lora_a.shape[1]
    assert current_lora_rank % self.tp_size == 0
    # Based on S-LoRA, we slice W13/W1/W3 A along the rank dim.
    shard_size = self.w13_lora_a_stacked[0].shape[2]
    start_idx = self.tp_rank * shard_size
    end_idx = (self.tp_rank + 1) * shard_size
    return w13_lora_a[:, start_idx:end_idx, :]

_slice_w2_a(w2_lora_a)

Applies to FusedMoEWithLoRA and FusedMoE3DWithLoRA

Source code in vllm/lora/layers/fused_moe.py
def _slice_w2_a(self, w2_lora_a: torch.Tensor) -> torch.Tensor:
    """
    Applies to FusedMoEWithLoRA and FusedMoE3DWithLoRA
    """
    if self.tp_size == 1:
        return w2_lora_a
    # w2_lora_a shape (num_experts,rank,input_size)
    shard_size = self.intermediate_size_per_partition
    start_idx = self.tp_rank * shard_size
    end_idx = (self.tp_rank + 1) * shard_size

    return w2_lora_a[:, :, start_idx:end_idx]

_slice_w2_b(w2_lora_b)

Applies to FusedMoEWithLoRA and FusedMoE3DWithLoRA

Source code in vllm/lora/layers/fused_moe.py
def _slice_w2_b(self, w2_lora_b: torch.Tensor) -> torch.Tensor:
    """
    Applies to FusedMoEWithLoRA and FusedMoE3DWithLoRA
    """
    if self.tp_size == 1 or not self.fully_sharded:
        return w2_lora_b
    # Based on S-LoRA, we slice W2 B along the hidden_size dim.
    # w2_lora_b shape (num_experts,output_size,rank)
    shard_size = self.w2_lora_b_stacked[0].shape[2]
    start_idx = self.tp_rank * shard_size
    end_idx = (self.tp_rank + 1) * shard_size

    return w2_lora_b[:, start_idx:end_idx, :]

can_replace_layer(source_layer, lora_config, packed_modules_list, model_config=None) classmethod

Returns True if the layer can be replaced by this LoRA layer.

Source code in vllm/lora/layers/fused_moe.py
@classmethod
def can_replace_layer(
    cls,
    source_layer: nn.Module,
    lora_config: LoRAConfig,
    packed_modules_list: list,
    model_config: PretrainedConfig | None = None,
) -> bool:
    """Returns True if the layer can be replaced by this LoRA layer.

    Matches MoERunner instances (resolved through ``maybe_get_oot_by_class``,
    so an out-of-tree override also counts) that have exactly two packed
    modules.
    """
    runner_cls = maybe_get_oot_by_class(MoERunner)
    if not isinstance(source_layer, runner_cls):
        return False
    return len(packed_modules_list) == 2

create_lora_weights(max_loras, lora_config, model_config=None)

Initializes lora matrices.

Source code in vllm/lora/layers/fused_moe.py
def create_lora_weights(
    self,
    max_loras: int,
    lora_config: LoRAConfig,
    model_config: PretrainedConfig | None = None,
) -> None:
    """Initializes lora matrices.

    Validates the EP/fully-sharded combination and records ``max_loras`` and
    ``fully_sharded`` from ``lora_config``. Allocates a zeroed
    ``adapter_enabled`` flag vector of length ``max_loras + 1`` and the
    stacked A/B buffers. Finally, flattens per-(lora, expert) views of those
    buffers into ``lora_a_stacked`` / ``lora_b_stacked``. Those flat lists
    are consumed by ``LoRALayerWeights.create_dummy_lora_weights`` to build
    dummy LoRA weights.
    """
    self._verify_ep_fs(lora_config)
    self.max_loras = lora_config.max_loras
    self.fully_sharded = lora_config.fully_sharded_loras

    self.adapter_enabled = torch.tensor(
        [0] * (max_loras + 1), dtype=torch.int, device=self.device
    )

    self._create_lora_a_weights(max_loras, lora_config)
    self._create_lora_b_weights(max_loras, lora_config)

    # TODO Optimize this section
    self.lora_a_stacked = []
    self.lora_b_stacked = []
    for lora_id in range(max_loras):
        for expert_id in range(self.local_num_experts):
            # Gated MoE order: gate_proj (w1), down_proj (w2), up_proj (w3).
            # Non-gated MoE order: up_proj (w1), down_proj (w2).
            a_views = [
                self.w13_lora_a_stacked[0][lora_id][expert_id],
                self.w2_lora_a_stacked[0][lora_id][expert_id],
            ]
            b_views = [
                self.w13_lora_b_stacked[0][lora_id][expert_id],
                self.w2_lora_b_stacked[0][lora_id][expert_id],
            ]
            # The w3 (up_proj) view exists only for gated MoE.
            if self._w13_slices == 2:
                a_views.append(self.w13_lora_a_stacked[1][lora_id][expert_id])
                b_views.append(self.w13_lora_b_stacked[1][lora_id][expert_id])
            self.lora_a_stacked.extend(a_views)
            self.lora_b_stacked.extend(b_views)

reset_lora(index)

Resets the lora weights at index back to 0.

Source code in vllm/lora/layers/fused_moe.py
def reset_lora(self, index: int):
    """Zero slot ``index`` in every W13/W2 LoRA buffer and disable it.

    Only the first ``_w13_slices`` W13 buffers are touched. The matching
    ``adapter_enabled`` flag is also cleared.
    """
    for slot in range(self._w13_slices):
        for buffers in (self.w13_lora_a_stacked, self.w13_lora_b_stacked):
            buffers[slot][index] = 0

    for buffers in (self.w2_lora_a_stacked, self.w2_lora_b_stacked):
        buffers[0][index] = 0

    self.adapter_enabled[index] = 0

set_lora(index, lora_a, lora_b)

Overwrites lora tensors at index.

Source code in vllm/lora/layers/fused_moe.py
def set_lora(
    self,
    index: int,
    lora_a: torch.Tensor | list[torch.Tensor],
    lora_b: torch.Tensor | list[torch.Tensor],
):
    """Overwrites lora tensors at index."""
    # Make mypy happy
    assert isinstance(lora_a, list)
    assert isinstance(lora_b, list)

    self.reset_lora(index)
    self.adapter_enabled[index] = 1

    num_experts = self.w13_lora_a_stacked[0].shape[1]

    w1_lora_a, w2_lora_a, w3_lora_a = lora_a
    w1_lora_b, w2_lora_b, w3_lora_b = lora_b

    # EP slicing is done once at add time in
    # LoRAModelManager._slice_moe_lora_ep, so by here the cached
    # tensors already match the local-expert dim of the stacked buffers.
    assert (
        num_experts
        == w1_lora_a.shape[0]
        == w2_lora_a.shape[0]
        == w3_lora_a.shape[0]
    )

    slliced_w1_lora_a = self._slice_w13_a(w1_lora_a)
    slliced_w1_lora_b = self._slice_w13_b(w1_lora_b)

    sliced_w2_lora_a = self._slice_w2_a(w2_lora_a)
    sliced_w2_lora_b = self._slice_w2_b(w2_lora_b)

    self.w13_lora_a_stacked[0][
        index, :, : slliced_w1_lora_a.shape[1], : slliced_w1_lora_a.shape[2]
    ].copy_(slliced_w1_lora_a, non_blocking=True)

    self.w13_lora_b_stacked[0][
        index, :, : slliced_w1_lora_b.shape[1], : slliced_w1_lora_b.shape[2]
    ].copy_(slliced_w1_lora_b, non_blocking=True)

    # Only copy w3 (up_proj) for gated MoE (_w13_slices == 2)
    if self._w13_slices == 2:
        slliced_w3_lora_a = self._slice_w13_a(w3_lora_a)
        slliced_w3_lora_b = self._slice_w13_b(w3_lora_b)

        self.w13_lora_a_stacked[1][
            index, :, : slliced_w3_lora_a.shape[1], : slliced_w3_lora_a.shape[2]
        ].copy_(slliced_w3_lora_a, non_blocking=True)

        self.w13_lora_b_stacked[1][
            index, :, : slliced_w3_lora_b.shape[1], : slliced_w3_lora_b.shape[2]
        ].copy_(slliced_w3_lora_b, non_blocking=True)

    self.w2_lora_a_stacked[0][
        index, :, : sliced_w2_lora_a.shape[1], : sliced_w2_lora_a.shape[2]
    ].copy_(sliced_w2_lora_a, non_blocking=True)

    self.w2_lora_b_stacked[0][
        index, :, : sliced_w2_lora_b.shape[1], : sliced_w2_lora_b.shape[2]
    ].copy_(sliced_w2_lora_b, non_blocking=True)

LogitsProcessorWithLoRA

Bases: BaseLayerWithLoRA

LoRA wrapper for LogitsProcessor, with extra logic to handle the application of the LoRA adapter and added LoRA vocabulary.

Parameters:

  • base_layer

    (LogitsProcessor) –

    LogitsProcessor layer

  • hidden_size

    (int) –

    hidden size of the model

  • dtype

    (dtype) –

    data type of the model

  • device

    (device) –

    device of the model

  • sharded_to_full_mapping

    (list[int] | None) –

    index mapping from sharded vocab to full vocab received from base_layer.get_sharded_to_full_mapping(). If None, no reindexing will be done.

Source code in vllm/lora/layers/logits_processor.py
class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
    """
    LoRA wrapper for LogitsProcessor that adds the adapter's contribution to
    the computed logits and handles the added LoRA vocabulary.

    Args:
        base_layer: the LogitsProcessor being wrapped
        hidden_size: hidden size of the model
        dtype: data type of the model
        device: device of the model
        sharded_to_full_mapping: gather indices, received from
            base_layer.get_sharded_to_full_mapping(), that reorder the
            TP-gathered logits so that position i holds token id i. If None,
            no reindexing will be done.
    """

    def __init__(
        self,
        base_layer: LogitsProcessor,
        hidden_size: int,
        dtype: torch.dtype,
        device: torch.device,
        sharded_to_full_mapping: list[int] | None,
    ) -> None:
        super().__init__()
        self.base_layer = base_layer
        self.hidden_size = hidden_size
        self.dtype = dtype
        self.device = device
        # Tensor-parallel topology of this process.
        self.tp_size = get_tensor_model_parallel_world_size()
        self.tp_rank = get_tensor_model_parallel_rank()
        self.sharded_to_full_mapping = sharded_to_full_mapping

    # ---- attributes forwarded unchanged from the wrapped LogitsProcessor ----

    @property
    def logits_as_input(self):
        """Forwarded from ``base_layer.logits_as_input``."""
        return self.base_layer.logits_as_input

    @property
    def vocab_size(self):
        """Forwarded from ``base_layer.vocab_size``."""
        return self.base_layer.vocab_size

    @property
    def scale(self):
        """Forwarded from ``base_layer.scale``."""
        return self.base_layer.scale

    @property
    def soft_cap(self):
        """Forwarded from ``base_layer.soft_cap``."""
        return self.base_layer.soft_cap

    @property
    def use_all_gather(self):
        """Forwarded from ``base_layer.use_all_gather``."""
        return self.base_layer.use_all_gather

    @property
    def org_vocab_size(self):
        """Forwarded from ``base_layer.org_vocab_size``."""
        return self.base_layer.org_vocab_size

    @property
    def include_gpu_probs_tensor(self):
        """Forwarded from ``base_layer.include_gpu_probs_tensor``."""
        return self.base_layer.include_gpu_probs_tensor

    @property
    def should_modify_greedy_probs_inplace(self):
        """Forwarded from ``base_layer.should_modify_greedy_probs_inplace``."""
        return self.base_layer.should_modify_greedy_probs_inplace

    def create_lora_weights(
        self,
        max_loras: int,
        lora_config: LoRAConfig,
        model_config: PretrainedConfig | None = None,
    ) -> None:
        """Allocate zeroed LoRA buffers and upload the reindex mapping.

        A buffers are ``(max_loras, 1, max_lora_rank, hidden_size)`` and
        B buffers are ``(max_loras, 1, vocab_size, max_lora_rank)``.

        Raises:
            ValueError: if the base vocab size exceeds 258048.
        """
        vocab = self.base_layer.vocab_size
        # TODO: Verify if this condition can be further relaxed
        if vocab > 258048:
            raise ValueError("When using LoRA, vocab size must be <= 258048")

        rank = lora_config.max_lora_rank
        alloc = {"dtype": lora_config.lora_dtype, "device": self.device}
        self.lora_a_stacked = torch.zeros(
            (max_loras, 1, rank, self.hidden_size), **alloc
        )
        self.lora_b_stacked = torch.zeros((max_loras, 1, vocab, rank), **alloc)

        mapping = self.sharded_to_full_mapping
        self.sharded_to_full_mapping_gpu = (
            None
            if mapping is None
            else torch.tensor(mapping, device=self.device, dtype=torch.long)
        )

    def reset_lora(self, index: int):
        """Zero the A and B buffers of slot ``index``."""
        for stacked in (self.lora_a_stacked, self.lora_b_stacked):
            stacked[index] = 0

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor | list[torch.Tensor],
        lora_b: torch.Tensor | list[torch.Tensor],
    ):
        """Copy single-tensor A/B weights into the corner of slot ``index``."""
        assert isinstance(lora_a, torch.Tensor)
        assert isinstance(lora_b, torch.Tensor)
        self.reset_lora(index)
        a_dst = self.lora_a_stacked[index, 0, : lora_a.shape[0], : lora_a.shape[1]]
        a_dst.copy_(lora_a, non_blocking=True)
        b_dst = self.lora_b_stacked[index, 0, : lora_b.shape[0], : lora_b.shape[1]]
        b_dst.copy_(lora_b, non_blocking=True)

    def _get_logits(
        self,
        hidden_states: torch.Tensor,
        lm_head: VocabParallelEmbedding,
        embedding_bias: torch.Tensor | None = None,
    ) -> torch.Tensor | None:
        """Compute next-token logits and add the LoRA contribution.

        Returns None when ``base_layer._gather_logits`` returns None
        (presumably on ranks that do not receive the gathered logits —
        confirm against ``_gather_logits``). Otherwise returns the logits
        trimmed to ``base_layer.vocab_size`` columns.
        """
        # Unwrap a LoRA-wrapped lm_head to reach its underlying layer.
        head = getattr(lm_head, "base_layer", lm_head)
        logits = head.quant_method.apply(head, hidden_states)
        if embedding_bias is not None:
            logits += embedding_bias

        # Gather logits for TP
        logits = self.base_layer._gather_logits(logits)
        if logits is None:
            return None

        index_map = self.sharded_to_full_mapping_gpu
        if index_map is not None:
            # After the TP gather, position i does not necessarily hold token
            # id i, so gather with index_map to restore index == token_id.
            # Example:
            #   org_vocab_size = 4, added_vocab_size = 2,
            #   pad_to_size = 8, tp_size = 2
            #   gathered position: [0, 1, 2,  3, 4, 5, 6,  7]
            #   token_id held:     [0, 1, 4, -1, 2, 3, 5, -1]
            # The gather index is therefore [0, 1, 4, 5, 2, 6, 3, 7], which
            # yields token ids [0, 1, 2, 3, 4, 5, -1, -1].
            logits = logits[:, index_map]

        lora_output: torch.Tensor | None = self.punica_wrapper.add_lora_logits(
            logits, hidden_states, self.lora_a_stacked, self.lora_b_stacked, 1.0
        )
        # If the platform cannot update in place, the result lives in
        # lora_output rather than in logits.
        if not current_platform.can_update_inplace():
            logits = lora_output

        # Remove paddings in vocab (if any).
        return logits[:, : self.base_layer.vocab_size]

    def forward(self, *args, **kwargs):
        """Run the base layer's class-level forward with this wrapper as self."""
        base_cls = type(self.base_layer)
        return base_cls.forward(self, *args, **kwargs)

    @classmethod
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: list,
        model_config: PretrainedConfig | None = None,
    ) -> bool:
        """Always False: the LogitsProcessor is wrapped via special handling."""
        return False

MergedColumnParallelLinearVariableSliceWithLoRA

Bases: MergedColumnParallelLinearWithLoRA

MergedColumnParallelLinear with variable number of slices (3+).

This handles cases where the checkpoint has a single weight for the whole module (not split into slices), but the layer itself has multiple slices.

Methods:

  • set_lora

    Override to handle single-tensor weights that need to be split into slices

Source code in vllm/lora/layers/column_parallel_linear.py
class MergedColumnParallelLinearVariableSliceWithLoRA(
    MergedColumnParallelLinearWithLoRA
):
    """MergedColumnParallelLinear LoRA wrapper for layers with 3+ slices.

    Also covers checkpoints that store one LoRA weight for the whole module
    even though the layer itself is made of several slices: ``set_lora``
    splits such weights per slice before delegating to the parent class.
    """

    @classmethod
    @_not_fully_sharded_can_replace
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: list,
        model_config: PretrainedConfig | None = None,
    ) -> bool:
        """Claim MergedColumnParallelLinear layers that have 3+ slices."""
        merged_cls = maybe_get_oot_by_class(MergedColumnParallelLinear)
        if not isinstance(source_layer, merged_cls):
            return False

        n_packed = len(packed_modules_list)
        if n_packed == 2:
            # Two packed modules are MergedColumnParallelLinearWithLoRA's job.
            return False
        if n_packed > 2:
            return True

        # Zero or one packed module: the checkpoint may hold a single weight
        # for a layer that nonetheless has 3+ slices, so consult the layer's
        # own output_sizes.
        return (
            hasattr(source_layer, "output_sizes")
            and len(source_layer.output_sizes) >= 3
        )

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor | list[torch.Tensor],
        lora_b: torch.Tensor | list[torch.Tensor],
    ):
        """Accept whole-module tensors and split them per slice.

        A single ``lora_a`` tensor of shape (rank, input_size) is shared by
        every slice. A single ``lora_b`` tensor of shape
        (total_output_size, rank) is cut along dim 0 according to
        ``base_layer.output_sizes``. Lists pass through unchanged to the
        parent ``set_lora``.
        """
        self.reset_lora(index)

        if isinstance(lora_a, torch.Tensor):
            lora_a = [lora_a] * self.n_slices

        if isinstance(lora_b, torch.Tensor):
            # Cumulative row offsets: slice k spans bounds[k]:bounds[k + 1].
            bounds = [0]
            for width in self.base_layer.output_sizes:
                bounds.append(bounds[-1] + width)
            lora_b = [lora_b[lo:hi, :] for lo, hi in zip(bounds, bounds[1:])]

        # The parent implementation expects per-slice lists.
        super().set_lora(index, lora_a, lora_b)

set_lora(index, lora_a, lora_b)

Override to handle single tensor weights that need to be split into slices.

Source code in vllm/lora/layers/column_parallel_linear.py
def set_lora(
    self,
    index: int,
    lora_a: torch.Tensor | list[torch.Tensor],
    lora_b: torch.Tensor | list[torch.Tensor],
):
    """Accept whole-module tensors and split them per slice.

    A single ``lora_a`` tensor of shape (rank, input_size) is shared by every
    slice. A single ``lora_b`` tensor of shape (total_output_size, rank) is
    cut along dim 0 according to ``base_layer.output_sizes``. Lists pass
    through unchanged to the parent ``set_lora``.
    """
    self.reset_lora(index)

    if isinstance(lora_a, torch.Tensor):
        lora_a = [lora_a] * self.n_slices

    if isinstance(lora_b, torch.Tensor):
        # Cumulative row offsets: slice k spans bounds[k]:bounds[k + 1].
        bounds = [0]
        for width in self.base_layer.output_sizes:
            bounds.append(bounds[-1] + width)
        lora_b = [lora_b[lo:hi, :] for lo, hi in zip(bounds, bounds[1:])]

    # The parent implementation expects per-slice lists.
    super().set_lora(index, lora_a, lora_b)

MergedColumnParallelLinearWithLoRA

Bases: ColumnParallelLinearWithLoRA

ColumnParallelLinear layer that is composed of 2 sublayers (slices) packed together (e.g. gate_proj + up_proj -> gate_up_proj).

This means we have 2 LoRAs, each applied to one half of the layer.

Both slices must have the same size.

Methods:

Source code in vllm/lora/layers/column_parallel_linear.py
class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
    """ColumnParallelLinear made of several packed sublayers ("slices").

    Example: gate_proj + up_proj packed into gate_up_proj. Every slice gets
    its own LoRA A/B pair, applied to that slice's part of the output. Slice
    sizes come from ``base_layer.output_sizes``, divided by the TP size for
    the local shard.
    """

    def __init__(
        self, base_layer: MergedColumnParallelLinear | QKVParallelLinear
    ) -> None:
        super().__init__(base_layer)
        # base_layer.output_sizes are global (un-sharded) sizes; each TP rank
        # owns 1/tp_size of every slice.
        self.output_sizes = self.base_layer.output_sizes
        self.output_slices = tuple(
            divide(full_size, self.tp_size) for full_size in self.output_sizes
        )
        self.n_slices = len(self.output_slices)
        # Every slice is sharded the same way, so all use this rank's id.
        self.output_ids = (self.tp_rank,) * self.n_slices

    def create_lora_weights(
        self,
        max_loras: int,
        lora_config: LoRAConfig,
        model_config: PretrainedConfig | None = None,
    ) -> None:
        """Allocate zeroed per-slice LoRA A and B buffers.

        Overridden mainly for maintainability. Each A buffer is
        ``(max_loras, 1, a_rank, input_size)``, where ``a_rank`` is
        ``max_lora_rank``, or ``max_lora_rank / tp_size`` with fully sharded
        LoRA. Each B buffer is ``(max_loras, 1, slice_size, max_lora_rank)``,
        one per output slice.
        """
        self.lora_config = lora_config

        full_rank = lora_config.max_lora_rank
        a_rank = (
            divide(full_rank, self.tp_size)
            if lora_config.fully_sharded_loras
            else full_rank
        )
        alloc = {"dtype": lora_config.lora_dtype, "device": self.device}

        self.lora_a_stacked = tuple(
            torch.zeros(max_loras, 1, a_rank, self.input_size, **alloc)
            for _ in range(self.n_slices)
        )
        self.lora_b_stacked = tuple(
            torch.zeros(max_loras, 1, slice_size, full_rank, **alloc)
            for slice_size in self.output_slices
        )

    def slice_lora_a(
        self, lora_a: list[torch.Tensor | None]
    ) -> list[torch.Tensor | None]:
        """Identity: this class does not slice LoRA A."""
        return lora_a

    def slice_lora_b(
        self, lora_b: list[torch.Tensor | None]
    ) -> list[torch.Tensor | None]:
        """Keep this TP rank's rows of each slice's B; None entries stay None."""
        sliced: list[torch.Tensor | None] = [None] * self.n_slices
        pairs = zip(self.output_ids, self.output_slices)
        for i, (shard_id, shard_size) in enumerate(pairs):
            b_i = lora_b[i]
            if b_i is None:
                continue
            lo = shard_size * shard_id
            sliced[i] = b_i[lo : lo + shard_size, :]
        return sliced

    def expand_packed_lora(
        self,
        lora_a: list[torch.Tensor],
        lora_b: list[torch.Tensor],
    ) -> tuple[list[torch.Tensor], list[torch.Tensor]]:
        """
        Split packed adapter groups into one (A, B) pair per output slice.

        Used when the adapter's groups don't match n_slices, e.g.
        in_proj_qkv (covering Q, K and V) plus in_proj_z for a 4-slice
        layer. Each group's B is cut along dim 0 by the un-sharded
        ``output_sizes`` of the consecutive slices it covers. Its A is
        reused for each of those slices.

        Raises:
            ValueError: if a group's B row count is not the sum of the output
                sizes of a run of consecutive remaining slices.
        """
        out_a: list[torch.Tensor] = []
        out_b: list[torch.Tensor] = []
        next_slice = 0
        for a_i, b_i in zip(lora_a, lora_b):
            b_rows = b_i.shape[0]
            # Walk forward from next_slice until the cumulative slice size
            # equals b_rows; that run of slices is what this group covers.
            running, n_covered = 0, 0
            for i in range(next_slice, self.n_slices):
                running += self.output_sizes[i]
                if running == b_rows:
                    n_covered = i - next_slice + 1
                    break
            else:
                raise ValueError(
                    f"Cannot determine how to split lora_b with {b_rows} rows "
                    f"into {self.n_slices} slices with output sizes "
                    f"{self.output_sizes} starting from index {next_slice}."
                )
            row = 0
            for size in self.output_sizes[next_slice : next_slice + n_covered]:
                out_b.append(b_i[row : row + size, :])
                out_a.append(a_i)
                row += size
            next_slice += n_covered
        return out_a, out_b

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor | list[torch.Tensor],
        lora_b: torch.Tensor | list[torch.Tensor],
    ):
        """Load per-slice LoRA weights into slot ``index``.

        The slot is zeroed first. A ``lora_b`` list whose length differs
        from ``n_slices`` (packed groups, e.g. in_proj_qkv + in_proj_z for a
        4-slice layer) is first expanded with ``expand_packed_lora``. Under
        tensor parallelism the weights are then sliced for this rank.
        ``None`` entries are skipped.
        """
        self.reset_lora(index)

        needs_expansion = isinstance(lora_b, list) and len(lora_b) != self.n_slices
        if needs_expansion:
            lora_a, lora_b = self.expand_packed_lora(lora_a, lora_b)

        if self.tp_size > 1:
            lora_a = self.slice_lora_a(lora_a)
            lora_b = self.slice_lora_b(lora_b)

        for i in range(self.n_slices):
            a_i = lora_a[i]
            if a_i is not None:
                a_dst = self.lora_a_stacked[i][
                    index, 0, : a_i.shape[0], : a_i.shape[1]
                ]
                a_dst.copy_(a_i, non_blocking=True)
            b_i = lora_b[i]
            if b_i is not None:
                b_dst = self.lora_b_stacked[i][
                    index, 0, : b_i.shape[0], : b_i.shape[1]
                ]
                b_dst.copy_(b_i, non_blocking=True)

    def apply(self, x: torch.Tensor, bias: torch.Tensor | None = None) -> torch.Tensor:
        """Compute the layer output including the LoRA delta.

        Effectively unsharded subclasses (tp_size == 1) that override
        ``forward`` reuse it through ``_apply_base_forward``. Everything
        else goes through ``_mcp_apply``.
        """
        merged_cls = maybe_get_oot_by_class(MergedColumnParallelLinear)
        base_type = type(self.base_layer)
        if (
            self.tp_size == 1
            and base_type is not merged_cls
            and base_type.forward is not merged_cls.forward
        ):
            return self._apply_base_forward(x)
        return _mcp_apply(x, bias, self)

    @classmethod
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: list,
        model_config: PretrainedConfig | None = None,
        decorate: bool = True,
    ) -> bool:
        """Claim MergedColumnParallelLinear layers with exactly 2 packed modules.

        The exact class qualifies unconditionally when ``decorate`` is False.
        Otherwise it qualifies only when LoRA is not fully sharded or
        tp_size is 1. Subclasses qualify only when effectively unsharded.
        """
        merged_cls = maybe_get_oot_by_class(MergedColumnParallelLinear)
        if not isinstance(source_layer, merged_cls):
            return False
        if len(packed_modules_list) != 2:
            return False

        tp_size = getattr(source_layer, "tp_size", 1)
        if type(source_layer) is not merged_cls:
            # Sharded subclasses may have custom communication semantics
            # that the generic merged-column LoRA path cannot preserve.
            return tp_size == 1

        if not decorate:
            return True
        return not lora_config.fully_sharded_loras or tp_size == 1

create_lora_weights(max_loras, lora_config, model_config=None)

The main reason for overriding this function is to enhance code maintainability.

Source code in vllm/lora/layers/column_parallel_linear.py
def create_lora_weights(
    self,
    max_loras: int,
    lora_config: LoRAConfig,
    model_config: PretrainedConfig | None = None,
) -> None:
    """
    Allocate zero-initialised stacked LoRA A and B buffers.

    This override exists mainly to keep the code maintainable. It creates one
    A buffer per slice (``self.n_slices``) and one B buffer per entry of
    ``self.output_slices``:

    * A: (max_loras, 1, rank, self.input_size). ``rank`` is ``max_lora_rank``
      divided by ``tp_size`` when ``fully_sharded_loras`` is enabled, and the
      full ``max_lora_rank`` otherwise.
    * B: (max_loras, 1, output_slice_size, max_lora_rank). B always keeps the
      full rank (presumably because sharded shrink results are gathered
      before the expand step; confirm in ``_mcp_apply``).
    """
    self.lora_config = lora_config

    full_rank = lora_config.max_lora_rank
    if lora_config.fully_sharded_loras:
        a_rank = divide(full_rank, self.tp_size)
    else:
        a_rank = full_rank

    def _zeros(rows: int, cols: int) -> torch.Tensor:
        # One zeroed buffer with a slot per LoRA and a singleton second dim.
        return torch.zeros(
            max_loras,
            1,
            rows,
            cols,
            dtype=lora_config.lora_dtype,
            device=self.device,
        )

    self.lora_a_stacked = tuple(
        _zeros(a_rank, self.input_size) for _ in range(self.n_slices)
    )
    self.lora_b_stacked = tuple(
        _zeros(out_size, full_rank) for out_size in self.output_slices
    )

expand_packed_lora(lora_a, lora_b)

Expand packed adapter groups into per-slice entries when the number of groups doesn't match n_slices, e.g. an adapter with in_proj_qkv (covering Q+K+V) plus in_proj_z.

Source code in vllm/lora/layers/column_parallel_linear.py
def expand_packed_lora(
    self,
    lora_a: list[torch.Tensor],
    lora_b: list[torch.Tensor],
) -> tuple[list[torch.Tensor], list[torch.Tensor]]:
    """
    Split packed adapter groups into one (A, B) pair per output slice.

    Needed when the adapter's group count differs from ``n_slices``, e.g. an
    ``in_proj_qkv`` group (covering Q+K+V) next to an ``in_proj_z`` group.
    Each group's B is assumed to cover a run of consecutive output slices
    whose ``output_sizes`` sum exactly to its row count. B is cut into those
    slices, and the group's A is repeated (same object) once per slice.

    Raises:
        ValueError: if no run of consecutive slices, starting at the first
            unclaimed slice, sums to a group's B row count.
    """
    out_a: list[torch.Tensor] = []
    out_b: list[torch.Tensor] = []
    slice_idx = 0
    for group_a, group_b in zip(lora_a, lora_b):
        num_rows = group_b.shape[0]

        # Accumulate slice widths from the first unclaimed slice until they
        # add up exactly to this group's row count.
        running = 0
        end_idx = None
        for k in range(slice_idx, self.n_slices):
            running += self.output_sizes[k]
            if running == num_rows:
                end_idx = k + 1
                break
        if end_idx is None:
            raise ValueError(
                f"Cannot determine how to split lora_b with {num_rows} rows "
                f"into {self.n_slices} slices with output sizes "
                f"{self.output_sizes} starting from index {slice_idx}."
            )

        row = 0
        for k in range(slice_idx, end_idx):
            width = self.output_sizes[k]
            out_b.append(group_b[row : row + width, :])
            out_a.append(group_a)
            row += width
        slice_idx = end_idx
    return out_a, out_b

MergedColumnParallelLinearWithShardedLoRA

Bases: MergedColumnParallelLinearWithLoRA

Differs from MergedColumnParallelLinearWithLoRA by also slicing the LoRA A matrices.

Based on S-LoRA, slicing happens along the rank dim.

Source code in vllm/lora/layers/column_parallel_linear.py
class MergedColumnParallelLinearWithShardedLoRA(MergedColumnParallelLinearWithLoRA):
    """
    MergedColumnParallelLinearWithLoRA variant that also shards the LoRA A
    matrices across tensor-parallel ranks.

    As in S-LoRA, A is split along its rank dimension.
    """

    def slice_lora_a(
        self, lora_a: list[torch.Tensor | None]
    ) -> list[torch.Tensor | None]:
        # The A buffers share one rank size (see the parent's
        # create_lora_weights), so slice 0 gives the per-rank row count.
        rows_per_rank = self.lora_a_stacked[0].shape[2]
        first_row = self.tp_rank * rows_per_rank
        last_row = first_row + rows_per_rank
        sliced: list[torch.Tensor | None] = []
        for a_mat in lora_a:
            # Missing sub-LoRAs stay None.
            sliced.append(None if a_mat is None else a_mat[first_row:last_row, :])
        return sliced

    def apply(self, x: torch.Tensor, bias: torch.Tensor | None = None) -> torch.Tensor:
        # The sharded variant always takes the generic merged-column path.
        return _mcp_apply(x, bias, self)

    @classmethod
    @_fully_sharded_can_replace
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: list,
        model_config: PretrainedConfig | None = None,
    ) -> bool:
        # Arguments are passed by keyword so the decorator can read them
        # easily; decorate=False skips the parent's own sharding check.
        return super().can_replace_layer(
            source_layer=source_layer,
            lora_config=lora_config,
            packed_modules_list=packed_modules_list,
            model_config=model_config,
            decorate=False,
        )

MergedQKVParallelLinearWithLoRA

Bases: MergedColumnParallelLinearWithLoRA

MergedColumnParallelLinear layer that is composed of 3 sublayers (slices) packed together in qkv proj fashion (q_proj + k_proj + v_proj -> qkv_proj).

This means we have 3 LoRAs, each applied to one slice of the layer.

Q slice may have different shape than K and V slices (which both have the same shape).

Methods:

  • create_lora_weights

    The main reason for overloading this function is to handle inconsistent weight dimensions in qkv lora.

Source code in vllm/lora/layers/column_parallel_linear.py
class MergedQKVParallelLinearWithLoRA(MergedColumnParallelLinearWithLoRA):
    """Merged-column LoRA layer for a fused QKV projection.

    The base layer packs three sub-projections (q_proj + k_proj + v_proj ->
    qkv_proj), and one LoRA is applied to each of those slices. The Q slice
    may differ in size from the K and V slices, which share a size.
    """

    def __init__(self, base_layer: QKVParallelLinear) -> None:
        super().__init__(base_layer)
        layer = self.base_layer
        # One LoRA slice per packed sub-projection.
        self.n_slices = len(layer.output_sizes)

        self.q_proj_shard_size = layer.num_heads * layer.head_size
        self.kv_proj_shard_size = layer.num_kv_heads * layer.head_size
        self.q_shard_id = self.tp_rank
        # Consecutive groups of num_kv_head_replicas ranks map to the same
        # KV shard index.
        self.kv_shard_id = self.tp_rank // layer.num_kv_head_replicas

        kv_size = self.kv_proj_shard_size
        self.output_slices = (self.q_proj_shard_size, kv_size, kv_size)
        kv_id = self.kv_shard_id
        self.output_ids = (self.q_shard_id, kv_id, kv_id)

    def create_lora_weights(
        self,
        max_loras: int,
        lora_config: LoRAConfig,
        model_config: PretrainedConfig | None = None,
    ) -> None:
        """
        Allocate the LoRA buffers by delegating to the parent implementation.

        The parent sizes each B buffer from ``self.output_slices``, which here
        holds the (possibly different) Q and K/V shard widths.
        """
        parent = super()
        parent.create_lora_weights(max_loras, lora_config, model_config)

    @classmethod
    @_not_fully_sharded_can_replace
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: list,
        model_config: PretrainedConfig | None = None,
    ) -> bool:
        # Exactly the (possibly out-of-tree) QKVParallelLinear type, with
        # separate q/k/v LoRAs.
        qkv_cls = maybe_get_oot_by_class(QKVParallelLinear)
        if type(source_layer) is not qkv_cls:
            return False
        return len(packed_modules_list) == 3

create_lora_weights(max_loras, lora_config, model_config=None)

The main reason for overloading this function is to handle inconsistent weight dimensions in qkv lora.

Source code in vllm/lora/layers/column_parallel_linear.py
def create_lora_weights(
    self,
    max_loras: int,
    lora_config: LoRAConfig,
    model_config: PretrainedConfig | None = None,
) -> None:
    """
    Allocate the LoRA buffers by delegating to the parent implementation.

    The parent sizes each B buffer from ``self.output_slices``, which holds
    the (possibly different) Q and K/V shard widths.
    """
    parent = super()
    parent.create_lora_weights(max_loras, lora_config, model_config)

MergedQKVParallelLinearWithShardedLoRA

Bases: MergedQKVParallelLinearWithLoRA

Differs from MergedQKVParallelLinearWithLoRA by also slicing the LoRA A matrices.

Based on S-LoRA, slicing happens along the rank dim.

Source code in vllm/lora/layers/column_parallel_linear.py
class MergedQKVParallelLinearWithShardedLoRA(MergedQKVParallelLinearWithLoRA):
    """
    MergedQKVParallelLinearWithLoRA variant that also shards the LoRA A
    matrices across tensor-parallel ranks.

    As in S-LoRA, A is split along its rank dimension.
    """

    def slice_lora_a(
        self, lora_a: list[torch.Tensor | None]
    ) -> list[torch.Tensor | None]:
        # lora_a holds one A matrix per Q/K/V sub-LoRA; any of them may be
        # None, in which case it is passed through as None.
        sliced: list[torch.Tensor | None] = []
        for idx in range(3):
            rows_per_rank = self.lora_a_stacked[idx].shape[2]
            first_row = self.tp_rank * rows_per_rank
            a_mat = lora_a[idx]
            if a_mat is None:
                sliced.append(None)
            else:
                sliced.append(a_mat[first_row : first_row + rows_per_rank, :])
        return sliced

    def apply(self, x: torch.Tensor, bias: torch.Tensor | None = None) -> torch.Tensor:
        # The sharded variant always takes the generic merged-column path.
        return _mcp_apply(x, bias, self)

    @classmethod
    @_fully_sharded_can_replace
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: list,
        model_config: PretrainedConfig | None = None,
    ) -> bool:
        # Arguments are passed by keyword so the decorator can read them
        # easily.
        return super().can_replace_layer(
            source_layer=source_layer,
            lora_config=lora_config,
            packed_modules_list=packed_modules_list,
            model_config=model_config,
            decorate=False,
        )

QKVParallelLinearWithLoRA

Bases: ColumnParallelLinearWithLoRA

ColumnParallelLinear layer that is specifically designed for qkv_proj. Certain models, such as chatglm3 and baichuan-7b, only contain a single LoRA within their qkv_proj layer.

During inference with Tensor Parallel, the weights of lora_b must be accurately partitioned according to the respective ranks.

Q slice may have different shape than K and V slices (which both have the same shape).

Source code in vllm/lora/layers/column_parallel_linear.py
class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
    """
    ColumnParallelLinear LoRA wrapper for a fused qkv_proj that carries one
    LoRA over the whole projection. Models such as chatglm3 and baichuan-7b
    have only a single LoRA in their qkv_proj layer.

    Under tensor parallelism, lora_b must be split so that each rank keeps
    the rows matching its own Q, K and V shards. The Q part may differ in
    size from the K and V parts, which share a size.
    """

    def __init__(self, base_layer: QKVParallelLinear) -> None:
        super().__init__(base_layer)
        layer = self.base_layer
        head_size = layer.head_size
        # Full (all-rank) and per-rank widths of the Q and KV sections.
        self.q_proj_total_size = layer.total_num_heads * head_size
        self.q_proj_shard_size = layer.num_heads * head_size
        self.kv_proj_shard_size = layer.num_kv_heads * head_size
        self.kv_proj_total_size = layer.total_num_kv_heads * head_size
        # The whole qkv_proj is covered by a single LoRA.
        self.n_slices = 1

    def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
        # Shard ids are (re)computed here and stored on the instance.
        self.q_shard_id = self.tp_rank
        self.kv_shard_id = self.tp_rank // self.base_layer.num_kv_head_replicas

        def _rows(section_start: int, width: int, shard: int) -> torch.Tensor:
            # Rows [section_start + width*shard, +width) of the full lora_b.
            begin = section_start + width * shard
            return lora_b[begin : begin + width, :]

        # lora_b is laid out as [Q (total) | K (total) | V (total)] rows.
        k_start = self.q_proj_total_size
        v_start = k_start + self.kv_proj_total_size
        q_part = _rows(0, self.q_proj_shard_size, self.q_shard_id)
        k_part = _rows(k_start, self.kv_proj_shard_size, self.kv_shard_id)
        v_part = _rows(v_start, self.kv_proj_shard_size, self.kv_shard_id)
        return torch.cat([q_part, k_part, v_part], dim=0)

    @classmethod
    @_not_fully_sharded_can_replace
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: list,
        model_config: PretrainedConfig | None = None,
    ) -> bool:
        # Exactly the (possibly out-of-tree) QKVParallelLinear type, with a
        # single packed LoRA.
        qkv_cls = maybe_get_oot_by_class(QKVParallelLinear)
        if type(source_layer) is not qkv_cls:
            return False
        return len(packed_modules_list) == 1

QKVParallelLinearWithShardedLoRA

Bases: QKVParallelLinearWithLoRA

Differs from QKVParallelLinearWithLoRA by also slicing the LoRA A matrix.

Based on S-LoRA, slicing happens along the rank dim.

Source code in vllm/lora/layers/column_parallel_linear.py
class QKVParallelLinearWithShardedLoRA(QKVParallelLinearWithLoRA):
    """
    QKVParallelLinearWithLoRA variant that also shards the LoRA A matrix
    across tensor-parallel ranks.

    As in S-LoRA, A is split along its rank dimension.
    """

    def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor:
        # Keep this rank's contiguous block of rows along the rank dim.
        rows_per_rank = self.lora_a_stacked[0].shape[2]
        first_row = self.tp_rank * rows_per_rank
        return lora_a[first_row : first_row + rows_per_rank, :]

    def apply(self, x: torch.Tensor, bias: torch.Tensor | None = None) -> torch.Tensor:
        # The sharded variant always takes the generic merged-column path.
        return _mcp_apply(x, bias, self)

    @classmethod
    @_fully_sharded_can_replace
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: list,
        model_config: PretrainedConfig | None = None,
    ) -> bool:
        # Arguments are passed by keyword so the decorator can read them
        # easily.
        return super().can_replace_layer(
            source_layer=source_layer,
            lora_config=lora_config,
            packed_modules_list=packed_modules_list,
            model_config=model_config,
            decorate=False,
        )

ReplicatedLinearWithLoRA

Bases: BaseLinearLayerWithLoRA

Methods:

  • forward

    Forward of ReplicatedLinearWithLoRA

  • slice_lora_a

    Slice lora a if splitting for tensor parallelism.

  • slice_lora_b

    Slice lora b if splitting with tensor parallelism.

Source code in vllm/lora/layers/replicated_linear.py
class ReplicatedLinearWithLoRA(BaseLinearLayerWithLoRA):
    def __init__(self, base_layer: ReplicatedLinear) -> None:
        super().__init__(base_layer)
        self.output_size = self.base_layer.output_size
        # Always exactly one LoRA slice; fixed at 1 for interface
        # compatibility with the multi-slice layers.
        self.n_slices = 1

    def forward(
        self, input_: torch.Tensor
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor | None]:
        """Forward of ReplicatedLinearWithLoRA.

        Args:
            input_: Tensor whose last dimension is `input_size`.

        Returns:
            ``output`` alone when ``base_layer.return_bias`` is false;
            otherwise ``(output, output_bias)``, where ``output_bias`` is the
            base-layer bias if ``skip_bias_add`` is set and ``None`` otherwise.
        """
        layer = self.base_layer
        if layer.skip_bias_add:
            matmul_bias = None
        else:
            matmul_bias = layer.bias

        output = self.apply(input_, matmul_bias)

        if not layer.return_bias:
            return output
        returned_bias = layer.bias if layer.skip_bias_add else None
        return output, returned_bias

    def apply(self, x: torch.Tensor, bias: torch.Tensor | None = None) -> torch.Tensor:
        # ReplicatedLinear subclasses (e.g. GateLinear) may override forward()
        # to dispatch custom kernels or adjust the output dtype, so the LoRA
        # delta is applied on top of the real base-layer output instead of
        # bypassing that path. ``bias`` is not passed on here.
        return self._apply_base_forward(x)

    # Replicated layers hold a full copy of the weights on every GPU, so they
    # are always replaced regardless of the fully-sharded LoRA setting.
    @classmethod
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: list,
        model_config: PretrainedConfig | None = None,
    ) -> bool:
        replicated_cls = maybe_get_oot_by_class(ReplicatedLinear)
        return isinstance(source_layer, replicated_cls)

    def slice_lora_a(
        self, lora_a: torch.Tensor | list[torch.Tensor | None]
    ) -> torch.Tensor | list[torch.Tensor | None]:
        """Return ``lora_a`` unchanged; replicated weights are never split."""
        unsliced = lora_a
        return unsliced

    def slice_lora_b(
        self, lora_b: torch.Tensor | list[torch.Tensor | None]
    ) -> torch.Tensor | list[torch.Tensor | None]:
        """Return ``lora_b`` unchanged; replicated weights are never split."""
        unsliced = lora_b
        return unsliced

forward(input_)

Forward of ReplicatedLinearWithLoRA

Parameters:

  • input_

    (Tensor) –

    Tensor whose last dimension is input_size.

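Returns:

  • output alone if base_layer.return_bias is not set; otherwise the tuple (output, output_bias), where output_bias is the layer bias when skip_bias_add is set and None otherwise.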

Source code in vllm/lora/layers/replicated_linear.py
def forward(
    self, input_: torch.Tensor
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor | None]:
    """Forward of ReplicatedLinearWithLoRA

    Args:
        input_: Tensor whose last dimension is `input_size`.

    Returns:
        - output
        - bias
    """
    bias = self.base_layer.bias if not self.base_layer.skip_bias_add else None

    # Matrix multiply.
    output = self.apply(input_, bias)

    output_bias = self.base_layer.bias if self.base_layer.skip_bias_add else None

    if not self.base_layer.return_bias:
        return output

    return output, output_bias

slice_lora_a(lora_a)

Slice lora a if splitting for tensor parallelism.

Source code in vllm/lora/layers/replicated_linear.py
def slice_lora_a(
    self, lora_a: torch.Tensor | list[torch.Tensor | None]
) -> torch.Tensor | list[torch.Tensor | None]:
    """Slice lora a if splitting for tensor parallelism."""
    return lora_a

slice_lora_b(lora_b)

Slice lora b if splitting with tensor parallelism.

Source code in vllm/lora/layers/replicated_linear.py
def slice_lora_b(
    self, lora_b: torch.Tensor | list[torch.Tensor | None]
) -> torch.Tensor | list[torch.Tensor | None]:
    """Slice lora b if splitting with tensor parallelism."""
    return lora_b

RowParallelLinearWithLoRA

Bases: BaseLinearLayerWithLoRA

Methods:

  • forward

    Forward of RowParallelLinear

Source code in vllm/lora/layers/row_parallel_linear.py
class RowParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
    def __init__(self, base_layer: RowParallelLinear) -> None:
        super().__init__(base_layer)
        layer = self.base_layer
        # Each rank only sees its own partition of the input features.
        self.input_size = layer.input_size_per_partition
        self.output_size = layer.output_size
        # A single LoRA covers the whole projection.
        self.n_slices = 1

    def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor:
        # Keep only the input-feature columns owned by this rank.
        width = self.input_size
        lo = self.tp_rank * width
        hi = lo + width
        return lora_a[:, lo:hi]

    def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
        # B maps to the full output_size, so it is not split here.
        return lora_b

    def forward(
        self, input_: torch.Tensor
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor | None]:
        """Forward of RowParallelLinear.

        Args:
            input_: tensor whose last dimension is `input_size`. If
                    `input_is_parallel` is set, then the last dimension
                    is `input_size // tp_size`.

        Returns:
            ``output`` alone when ``base_layer.return_bias`` is false;
            otherwise ``(output, output_bias)``, where ``output_bias`` is the
            base-layer bias if ``skip_bias_add`` is set and ``None`` otherwise.
        """
        layer = self.base_layer
        if not layer.input_is_parallel:
            # TODO: simplify the input split below
            pieces = split_tensor_along_last_dim(input_, num_partitions=self.tp_size)
            local_input = pieces[self.tp_rank].contiguous()
        else:
            local_input = input_

        # Only rank 0 adds the bias, so summing the partial outputs across
        # ranks includes it once; with skip_bias_add it is returned instead.
        if self.tp_rank > 0 or layer.skip_bias_add:
            local_bias = None
        else:
            local_bias = layer.bias

        partial = self.apply(local_input, local_bias)
        if layer.reduce_results and self.tp_size > 1:
            output = tensor_model_parallel_all_reduce(partial)
        else:
            output = partial

        if not layer.return_bias:
            return output
        returned_bias = layer.bias if layer.skip_bias_add else None
        return output, returned_bias

    @classmethod
    @_not_fully_sharded_can_replace
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: list,
        model_config: PretrainedConfig | None = None,
    ) -> bool:
        # Exact type match only: subclasses are not wrapped by this class.
        row_cls = maybe_get_oot_by_class(RowParallelLinear)
        return type(source_layer) is row_cls

forward(input_)

Forward of RowParallelLinear

Parameters:

  • input_

    (Tensor) –

    tensor whose last dimension is input_size. If input_is_parallel is set, then the last dimension is input_size // tp_size.

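Returns:

  • output alone if base_layer.return_bias is not set; otherwise the tuple (output, output_bias), where output_bias is the layer bias when skip_bias_add is set and None otherwise.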

Source code in vllm/lora/layers/row_parallel_linear.py
def forward(
    self, input_: torch.Tensor
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor | None]:
    """Forward of RowParallelLinear

    Args:
        input_: tensor whose last dimension is `input_size`. If
                `input_is_parallel` is set, then the last dimension
                is `input_size // tp_size`.

    Returns:
        - output
        - bias
    """
    # set up backprop all-reduce.
    if self.base_layer.input_is_parallel:
        input_parallel = input_
    else:
        # TODO: simplify code below
        split_input = split_tensor_along_last_dim(
            input_, num_partitions=self.tp_size
        )
        input_parallel = split_input[self.tp_rank].contiguous()

    # Matrix multiply.
    bias_ = (
        None
        if (self.tp_rank > 0 or self.base_layer.skip_bias_add)
        else self.base_layer.bias
    )
    output_parallel = self.apply(input_parallel, bias_)
    if self.base_layer.reduce_results and self.tp_size > 1:
        output = tensor_model_parallel_all_reduce(output_parallel)
    else:
        output = output_parallel

    output_bias = self.base_layer.bias if self.base_layer.skip_bias_add else None
    if not self.base_layer.return_bias:
        return output

    return output, output_bias

RowParallelLinearWithShardedLoRA

Bases: RowParallelLinearWithLoRA

Differs from RowParallelLinearWithLoRA by also slicing the LoRA B matrix.

Based on S-LoRA, slicing happens along the output dim. This yields a combined partial sum from the row parallel base layer and column partitioned output from the LoRA.

Source code in vllm/lora/layers/row_parallel_linear.py
class RowParallelLinearWithShardedLoRA(RowParallelLinearWithLoRA):
    """
    RowParallelLinearWithLoRA variant that also shards the LoRA B matrix
    across tensor-parallel ranks.

    As in S-LoRA, B is split along the output dimension. The result combines
    the row-parallel base layer's partial sum with this rank's
    column-partitioned LoRA output.
    """

    def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
        # Keep this rank's contiguous block of output rows.
        rows_per_rank = self.lora_b_stacked[0].shape[2]
        first_row = self.tp_rank * rows_per_rank
        return lora_b[first_row : first_row + rows_per_rank, :]

    def apply(self, x: torch.Tensor, bias: torch.Tensor | None = None) -> torch.Tensor:
        base_out = self.base_layer.quant_method.apply(self.base_layer, x, bias)

        # Flatten leading dims so the punica kernels see 2-D (tokens, features).
        tokens = x.view(-1, x.shape[-1])
        orig_shape = base_out.shape
        flat_out = base_out.view(-1, base_out.shape[-1])

        shrink_buf = torch.zeros(
            (self.n_slices, tokens.shape[0], self.lora_a_stacked[0].shape[2]),
            dtype=torch.float32,
            device=tokens.device,
        )
        shrunk: torch.Tensor | None = self.punica_wrapper.add_shrink(
            shrink_buf, tokens, self.lora_a_stacked, 1.0
        )
        # Platforms without in-place updates return a fresh tensor instead.
        if not current_platform.can_update_inplace():
            shrink_buf = shrunk
        # A only sees this rank's input partition, so the shrink results are
        # partial sums that must be reduced across ranks.
        if self.tp_size > 1:
            shrink_buf = tensor_model_parallel_all_reduce(shrink_buf)

        # Following S-LoRA, the column-partitioned LoRA output is added into
        # this rank's slice of the (row-parallel, partial-sum) base output, so
        # one standard all_reduce afterwards fuses the all_gather and the
        # all_reduce. The returned tensor therefore differs from a normal
        # row-parallel output and must be reduced before use. The write
        # offset depends on the rank.
        rows_per_rank = self.lora_b_stacked[0].shape[2]
        expanded: torch.Tensor | None = self.punica_wrapper.add_expand(
            flat_out,
            shrink_buf,
            self.lora_b_stacked,
            self.output_slices,
            offset_start=self.tp_rank * rows_per_rank,
            add_input=True,
        )
        if not current_platform.can_update_inplace():
            flat_out = expanded

        return flat_out.view(*orig_shape)

    @classmethod
    @_fully_sharded_can_replace
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: list,
        model_config: PretrainedConfig | None = None,
    ) -> bool:
        # Arguments are passed by keyword so the decorator can read them
        # easily.
        return super().can_replace_layer(
            source_layer=source_layer,
            lora_config=lora_config,
            packed_modules_list=packed_modules_list,
            model_config=model_config,
            decorate=False,
        )