Skip to content

vllm.lora.layers.fused_moe

Classes:

FusedMoE3DWithLoRA

Bases: FusedMoEWithLoRA

Methods:

Attributes:

Source code in vllm/lora/layers/fused_moe.py
class FusedMoE3DWithLoRA(FusedMoEWithLoRA):
    """Fused-MoE LoRA layer that stores the w13 LoRA as one combined slice.

    Unlike ``FusedMoEWithLoRA`` (separate w1 and w3 slices for gated MoE),
    this variant keeps a single w13 slice whose LoRA-B output dim is
    ``2 * intermediate_size_per_partition``: the w1 and w3 halves are laid out
    back-to-back, or interleaved row by row for GPT-OSS. Adapters are loaded
    as ``[w13, w2]`` pairs (see ``set_lora``).
    """

    def __init__(self, base_layer: MoERunner):
        super().__init__(base_layer)
        # One w13 slice carries both the w1 and w3 halves.
        # NOTE(review): super().__init__ derives self.n_slices from the
        # parent's _w13_slices value before this override, and n_slices is not
        # recomputed here — confirm that is intended for the 3D layout.
        self._w13_slices = 1

    def _create_lora_b_weights(self, max_loras: int, lora_config: LoRAConfig):
        """Allocate zero-filled LoRA-B buffers for w13 (fused w1+w3) and w2.

        Shapes:
          * w13: ``(max_loras, local_num_experts,
            2 * intermediate_size_per_partition, max_lora_rank)``
          * w2:  ``(max_loras, local_num_experts, hidden_size, max_lora_rank)``
            with ``hidden_size`` divided by ``tp_size`` when fully sharded.
        """
        self.w13_lora_b_stacked: tuple[torch.Tensor, ...] = tuple(
            torch.zeros(
                (
                    max_loras,
                    self.local_num_experts,
                    self.intermediate_size_per_partition * 2,
                    lora_config.max_lora_rank,
                ),
                dtype=lora_config.lora_dtype,
                device=self.device,
            )
            for _ in range(self._w13_slices)
        )
        self.w2_lora_b_stacked: tuple[torch.Tensor, ...] = (
            torch.zeros(
                (
                    max_loras,
                    self.local_num_experts,
                    self.hidden_size
                    if not self.fully_sharded
                    else divide(self.hidden_size, self.tp_size),
                    lora_config.max_lora_rank,
                ),
                dtype=lora_config.lora_dtype,
                device=self.device,
            ),
        )

    def create_lora_weights(
        self,
        max_loras: int,
        lora_config: LoRAConfig,
        model_config: PretrainedConfig | None = None,
    ) -> None:
        """Initializes lora matrices.

        ``model_config`` is required (asserted to be a ``PretrainedConfig``).
        Its first listed architecture selects the w13 LoRA-B layout used by
        ``_slice_w13_b``. When the config lists no architectures, the default
        (non-interleaved) layout is used. Unlike the parent implementation,
        this does not build the flat ``lora_a_stacked`` / ``lora_b_stacked``
        lists.
        """

        assert isinstance(model_config, PretrainedConfig)
        self._verify_ep_fs(lora_config)
        # `architectures` may be unset (None) or empty on a PretrainedConfig.
        # Fall back to None so _slice_w13_b takes the non-interleaved path
        # instead of failing on the index.
        architectures = model_config.architectures or []
        self._base_model = architectures[0] if architectures else None
        self.max_loras = lora_config.max_loras
        self.fully_sharded = lora_config.fully_sharded_loras

        # One flag per adapter slot (+1 extra entry); set_lora flips it to 1.
        self.adapter_enabled = torch.tensor(
            [0] * (max_loras + 1), dtype=torch.int, device=self.device
        )

        self._create_lora_a_weights(max_loras, lora_config)
        self._create_lora_b_weights(max_loras, lora_config)

    def _slice_w13_b(self, w13_lora_b: torch.Tensor):
        """Return this TP rank's shard of a fused w13 LoRA-B tensor.

        The w1 and w3 halves are sliced independently, each to
        ``intermediate_size_per_partition`` rows for ``self.tp_rank``, then
        recombined in the same layout they arrived in.
        """
        if self.tp_size == 1:
            return w13_lora_b

        # w13_lora_b shape (num_experts,output_size,rank)
        shard_size = self.intermediate_size_per_partition
        start_idx = self.tp_rank * shard_size
        end_idx = (self.tp_rank + 1) * shard_size
        # HACK: Currently, only GPT-OSS is in interleaved order
        if self._base_model == "GptOssForCausalLM":
            # For models like GPT-OSS, the weights of w1 (gate_proj) and w3 (up_proj)
            # are in interleaved order, and the corresponding LoRA must be
            # de-interleaved, sliced, and re-interleaved.
            w1_lora_b = w13_lora_b[:, ::2, :]
            w3_lora_b = w13_lora_b[:, 1::2, :]
            sliced_w1_lora_b = w1_lora_b[:, start_idx:end_idx, :]
            sliced_w3_lora_b = w3_lora_b[:, start_idx:end_idx, :]

            # stack on a new dim 2 then flatten -> rows alternate w1, w3.
            return torch.stack([sliced_w1_lora_b, sliced_w3_lora_b], dim=2).flatten(
                1, 2
            )
        else:
            # Default layout: first half of the rows is w1, second half is w3.
            slice_size = w13_lora_b.shape[1] // 2
            w1_lora_b = w13_lora_b[:, :slice_size, :]
            w3_lora_b = w13_lora_b[:, slice_size:, :]
            sliced_w1_lora_b = w1_lora_b[:, start_idx:end_idx, :]
            sliced_w3_lora_b = w3_lora_b[:, start_idx:end_idx, :]

            return torch.cat([sliced_w1_lora_b, sliced_w3_lora_b], dim=1)

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor | list[torch.Tensor],
        lora_b: torch.Tensor | list[torch.Tensor],
    ):
        """Overwrites lora tensors at index.

        ``lora_a`` and ``lora_b`` must each be a two-element list
        ``[w13, w2]``. The slot is cleared with ``reset_lora``. Each tensor is
        sliced for this TP rank and copied into the leading sub-block of the
        slot's buffer, because the buffer may be larger (e.g. sized for the
        maximum LoRA rank).
        """
        # Make mypy happy
        assert isinstance(lora_a, list)
        assert isinstance(lora_b, list)
        assert len(lora_a) == len(lora_b) == 2

        self.reset_lora(index)
        self.adapter_enabled[index] = 1

        w13_lora_a, w2_lora_a = lora_a
        w13_lora_b, w2_lora_b = lora_b

        sliced_w13_lora_a = self._slice_w13_a(w13_lora_a)
        sliced_w13_lora_b = self._slice_w13_b(w13_lora_b)

        sliced_w2_lora_a = self._slice_w2_a(w2_lora_a)
        sliced_w2_lora_b = self._slice_w2_b(w2_lora_b)

        self.w13_lora_a_stacked[0][
            index, :, : sliced_w13_lora_a.shape[1], : sliced_w13_lora_a.shape[2]
        ].copy_(sliced_w13_lora_a, non_blocking=True)
        self.w2_lora_a_stacked[0][
            index, :, : sliced_w2_lora_a.shape[1], : sliced_w2_lora_a.shape[2]
        ].copy_(sliced_w2_lora_a, non_blocking=True)

        self.w13_lora_b_stacked[0][
            index, :, : sliced_w13_lora_b.shape[1], : sliced_w13_lora_b.shape[2]
        ].copy_(sliced_w13_lora_b, non_blocking=True)
        self.w2_lora_b_stacked[0][
            index, :, : sliced_w2_lora_b.shape[1], : sliced_w2_lora_b.shape[2]
        ].copy_(sliced_w2_lora_b, non_blocking=True)

    @property
    def w13_input_size(self):
        """
        Full (unsharded) input size of the w13 LoRA: the last dim of the
        w13 LoRA-A buffer.
        """
        return self.w13_lora_a_stacked[0].shape[-1]

    @property
    def w13_output_size(self):
        """
        Full w13 LoRA output size across all TP ranks: the per-rank w13
        LoRA-B output dim multiplied by ``tp_size``.
        """
        return self.w13_lora_b_stacked[0].shape[-2] * self.tp_size

    @property
    def w2_input_size(self):
        """
        Full w2 LoRA input size across all TP ranks: the per-rank w2 LoRA-A
        input dim multiplied by ``tp_size``.
        """
        return self.w2_lora_a_stacked[0].shape[-1] * self.tp_size

    @property
    def w2_output_size(self):
        """
        Full w2 LoRA output size, i.e. the layer's hidden size.
        """
        return self.hidden_size

    @classmethod
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: list,
        model_config: PretrainedConfig | None = None,
    ) -> bool:
        """Returns True if the layer can be replaced by this LoRA layer.

        Requires an instance of ``maybe_get_oot_by_class(MoERunner)`` with
        exactly one packed module.
        """
        # source_layer is MoERunner
        moe_cls = maybe_get_oot_by_class(MoERunner)
        return isinstance(source_layer, moe_cls) and len(packed_modules_list) == 1

w13_input_size property

Full (unsharded) input size of the w13 LoRA: the last dimension of the w13 LoRA-A buffer.

w13_output_size property

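Full w13 LoRA output size across all tensor-parallel ranks: the per-rank w13 LoRA-B output dimension multiplied by the TP size.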

w2_input_size property

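Full w2 LoRA input size across all tensor-parallel ranks: the per-rank w2 LoRA-A input dimension multiplied by the TP size.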

w2_output_size property

Full w2 LoRA output size, i.e. the layer's hidden size.

can_replace_layer(source_layer, lora_config, packed_modules_list, model_config=None) classmethod

Returns True if the layer can be replaced by this LoRA layer.

Source code in vllm/lora/layers/fused_moe.py
@classmethod
def can_replace_layer(
    cls,
    source_layer: nn.Module,
    lora_config: LoRAConfig,
    packed_modules_list: list,
    model_config: PretrainedConfig | None = None,
) -> bool:
    """Returns True if the layer can be replaced by this LoRA layer.

    The layer must be an instance of ``maybe_get_oot_by_class(MoERunner)``
    and come with exactly one packed module.
    """
    runner_cls = maybe_get_oot_by_class(MoERunner)
    if not isinstance(source_layer, runner_cls):
        return False
    return len(packed_modules_list) == 1

create_lora_weights(max_loras, lora_config, model_config=None)

Initializes lora matrices.

Source code in vllm/lora/layers/fused_moe.py
def create_lora_weights(
    self,
    max_loras: int,
    lora_config: LoRAConfig,
    model_config: PretrainedConfig | None = None,
) -> None:
    """Initializes lora matrices.

    ``model_config`` is required (asserted to be a ``PretrainedConfig``).
    Its first listed architecture is stored in ``self._base_model``. When the
    config lists no architectures, ``self._base_model`` is set to None instead
    of raising on the index.
    """

    assert isinstance(model_config, PretrainedConfig)
    self._verify_ep_fs(lora_config)
    # `architectures` may be unset (None) or empty on a PretrainedConfig.
    architectures = model_config.architectures or []
    self._base_model = architectures[0] if architectures else None
    self.max_loras = lora_config.max_loras
    self.fully_sharded = lora_config.fully_sharded_loras

    # One int flag per adapter slot (+1 extra entry), all initially 0.
    self.adapter_enabled = torch.tensor(
        [0] * (max_loras + 1), dtype=torch.int, device=self.device
    )

    self._create_lora_a_weights(max_loras, lora_config)
    self._create_lora_b_weights(max_loras, lora_config)

set_lora(index, lora_a, lora_b)

Overwrites lora tensors at index.

Source code in vllm/lora/layers/fused_moe.py
def set_lora(
    self,
    index: int,
    lora_a: torch.Tensor | list[torch.Tensor],
    lora_b: torch.Tensor | list[torch.Tensor],
):
    """Overwrites lora tensors at index."""
    # Make mypy happy
    assert isinstance(lora_a, list)
    assert isinstance(lora_b, list)
    assert len(lora_a) == len(lora_b) == 2

    self.reset_lora(index)
    self.adapter_enabled[index] = 1

    w13_lora_a, w2_lora_a = lora_a
    w13_lora_b, w2_lora_b = lora_b

    sliced_w13_lora_a = self._slice_w13_a(w13_lora_a)
    sliced_w13_lora_b = self._slice_w13_b(w13_lora_b)

    sliced_w2_lora_a = self._slice_w2_a(w2_lora_a)
    sliced_w2_lora_b = self._slice_w2_b(w2_lora_b)

    self.w13_lora_a_stacked[0][
        index, :, : sliced_w13_lora_a.shape[1], : sliced_w13_lora_a.shape[2]
    ].copy_(sliced_w13_lora_a, non_blocking=True)
    self.w2_lora_a_stacked[0][
        index, :, : sliced_w2_lora_a.shape[1], : sliced_w2_lora_a.shape[2]
    ].copy_(sliced_w2_lora_a, non_blocking=True)

    self.w13_lora_b_stacked[0][
        index, :, : sliced_w13_lora_b.shape[1], : sliced_w13_lora_b.shape[2]
    ].copy_(sliced_w13_lora_b, non_blocking=True)
    self.w2_lora_b_stacked[0][
        index, :, : sliced_w2_lora_b.shape[1], : sliced_w2_lora_b.shape[2]
    ].copy_(sliced_w2_lora_b, non_blocking=True)

FusedMoEWithLoRA

Bases: BaseLayerWithLoRA

Methods:

Source code in vllm/lora/layers/fused_moe.py
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
class FusedMoEWithLoRA(BaseLayerWithLoRA):
    """LoRA wrapper around a fused-MoE ``MoERunner``.

    Keeps per-adapter LoRA A/B buffers for the expert projections: w13 is
    stored as ``self._w13_slices`` slices (w1 and w3 separately for gated
    MoE, w1 only otherwise), and w2 as a single slice. On construction it
    replaces the base layer's quant method with a ``FusedMoEModularMethod``
    around a kernel that reports LoRA support. On every ``set_mapping`` call
    it hands the buffers to that kernel through a freshly built
    ``MoELoRAContext``.
    """

    def __init__(self, base_layer: MoERunner) -> None:
        super().__init__()
        self.base_layer = base_layer
        self.moe_config = base_layer.moe_config
        self._shared_experts = base_layer._shared_experts
        self._ep_check()

        routed_experts = self.base_layer.routed_experts
        assert not routed_experts.quant_method.is_monolithic, (
            "Monolithic kernels are not supported for Fused MoE LoRA."
        )

        # Use the MoE-aware TP rank/size: when EP is active, FusedMoE collapses
        # moe_parallel_config.tp_size to 1 (experts are sharded across the
        # TP group instead).
        moe_parallel_config = self.moe_config.moe_parallel_config
        self.tp_size = moe_parallel_config.tp_size
        self.tp_rank = moe_parallel_config.tp_rank
        self.device = _get_lora_device(base_layer)

        self._enable_aux_cuda_stream = envs.VLLM_LORA_ENABLE_DUAL_STREAM
        self._init_lora_stream_context()
        # For non-gated MoE (is_act_and_mul=False), only 1 slice is needed
        # since there's only up_proj (w1), not gate_proj + up_proj (w1 + w3)
        self._w13_slices = 2 if base_layer.moe_config.is_act_and_mul else 1
        # Mirrors per-(lora_id) layout of `self.lora_a_stacked` (built in
        # `create_lora_weights`) so `create_dummy_lora`'s n_slices fallback
        # matches `lora_a_stacked` length under EP.
        self.n_slices = self.local_num_experts * (self._w13_slices + 1)

        routed_experts._ensure_moe_quant_config_init()
        if getattr(routed_experts.quant_method, "supports_internal_mk", False):
            moe_kernel = routed_experts.quant_method.moe_kernel
        else:
            prepare_finalize = MoEPrepareAndFinalizeNoDPEPModular()
            moe_kernel = FusedMoEKernel(
                prepare_finalize,
                routed_experts.quant_method.select_gemm_impl(
                    prepare_finalize, routed_experts
                ),
            )
        assert moe_kernel.supports_lora(), (
            f"{type(moe_kernel.fused_experts).__name__} does not support LoRA. "
            "For unquantized MoE, set moe_backend='triton' or moe_backend='auto' "
            "(auto selects Triton automatically when LoRA is enabled). "
            "For quantized MoE, mix LoRAExpertsMixin into the experts class "
            "and consume self._lora_context in apply()."
        )
        self._moe_kernel = moe_kernel
        self.base_layer._replace_quant_method(
            FusedMoEModularMethod(self.base_layer._quant_method, moe_kernel)
        )

    @property
    def hidden_size(self) -> int:
        """Model hidden dim (``moe_config.hidden_dim``)."""
        return self.moe_config.hidden_dim

    @property
    def local_num_experts(self) -> int:
        """Number of experts held on this rank."""
        return self.moe_config.num_local_experts

    @property
    def global_num_experts(self) -> int:
        """Total number of experts across all ranks."""
        return self.moe_config.num_experts

    @property
    def ep_rank(self) -> int:
        """Expert-parallel rank from the MoE parallel config."""
        return self.moe_config.moe_parallel_config.ep_rank

    @property
    def use_ep(self) -> bool:
        """Whether expert parallelism is enabled."""
        return self.moe_config.moe_parallel_config.use_ep

    @property
    def intermediate_size_per_partition(self) -> int:
        """Per-TP-rank expert intermediate size."""
        return self.moe_config.intermediate_size_per_partition

    def _init_lora_stream_context(self) -> None:
        """Set up the optional auxiliary CUDA stream and events for LoRA.

        Leaves both as None unless the dual-stream env flag is on and the
        platform is CUDA-alike.
        """
        self._lora_stream: torch.cuda.Stream | None = None
        self._events: tuple[torch.cuda.Event, ...] | None = None
        if not self._enable_aux_cuda_stream:
            return
        if not current_platform.is_cuda_alike():
            return
        self._lora_stream = _get_lora_aux_cuda_stream()
        # 4 events: 2 per (base GEMM, LoRA) pair so w13 and w2 don't reuse
        # the same event objects; reuse-within-a-pair is fine because the
        # second pair starts only after intermediate_cache1.add_() has joined.
        self._events = tuple(torch.cuda.Event() for _ in range(4))

    def _build_lora_context(self):
        """Bundle the LoRA buffers and runtime settings into an MoELoRAContext.

        The aux stream and events are passed only when dual-stream is enabled,
        LoRA is not fully sharded, and a stream was actually created.
        """
        use_dual_stream = (
            self._enable_aux_cuda_stream
            and not self.fully_sharded
            and self._lora_stream is not None
        )
        return MoELoRAContext(
            w13_lora_a_stacked=self.w13_lora_a_stacked,
            w13_lora_b_stacked=self.w13_lora_b_stacked,
            w2_lora_a_stacked=self.w2_lora_a_stacked,
            w2_lora_b_stacked=self.w2_lora_b_stacked,
            adapter_enabled=self.adapter_enabled,
            max_loras=self.max_loras,
            top_k=self.moe_config.experts_per_token,
            w13_num_slices=self._w13_slices,
            fully_sharded=self.fully_sharded,
            tp_rank=self.tp_rank,
            tp_size=self.tp_size,
            local_num_experts=self.local_num_experts,
            punica_wrapper=self.punica_wrapper,
            use_tuned_config=bool(envs.VLLM_TUNED_CONFIG_FOLDER),
            aux_stream=self._lora_stream if use_dual_stream else None,
            events=self._events if use_dual_stream else None,
        )

    def _create_lora_a_weights(
        self,
        max_loras: int,
        lora_config: LoRAConfig,
    ):
        """Allocate zero-filled LoRA-A buffers.

        w13 (one per slice): ``(max_loras, local_num_experts, rank, hidden_size)``,
        where rank is divided by ``tp_size`` when fully sharded.
        w2: ``(max_loras, local_num_experts, rank, intermediate_size_per_partition)``.
        """
        self.w13_lora_a_stacked: tuple[torch.Tensor, ...] = tuple(
            torch.zeros(
                (
                    max_loras,
                    self.local_num_experts,
                    lora_config.max_lora_rank
                    if not self.fully_sharded
                    else divide(lora_config.max_lora_rank, self.tp_size),
                    self.hidden_size,
                ),
                dtype=lora_config.lora_dtype,
                device=self.device,
            )
            for _ in range(self._w13_slices)
        )
        self.w2_lora_a_stacked: tuple[torch.Tensor, ...] = (
            torch.zeros(
                (
                    max_loras,
                    self.local_num_experts,
                    lora_config.max_lora_rank,
                    self.intermediate_size_per_partition,
                ),
                dtype=lora_config.lora_dtype,
                device=self.device,
            ),
        )

    def _create_lora_b_weights(self, max_loras: int, lora_config: LoRAConfig):
        """Allocate zero-filled LoRA-B buffers.

        w13 (one per slice):
        ``(max_loras, local_num_experts, intermediate_size_per_partition, rank)``.
        w2: ``(max_loras, local_num_experts, hidden_size, rank)``, with
        hidden_size divided by ``tp_size`` when fully sharded.
        """
        self.w13_lora_b_stacked: tuple[torch.Tensor, ...] = tuple(
            torch.zeros(
                (
                    max_loras,
                    self.local_num_experts,
                    self.intermediate_size_per_partition,
                    lora_config.max_lora_rank,
                ),
                dtype=lora_config.lora_dtype,
                device=self.device,
            )
            for _ in range(self._w13_slices)
        )
        self.w2_lora_b_stacked: tuple[torch.Tensor, ...] = (
            torch.zeros(
                (
                    max_loras,
                    self.local_num_experts,
                    self.hidden_size
                    if not self.fully_sharded
                    else divide(self.hidden_size, self.tp_size),
                    lora_config.max_lora_rank,
                ),
                dtype=lora_config.lora_dtype,
                device=self.device,
            ),
        )

    def _ep_check(self):
        """Reject EP configurations this LoRA layer does not support."""
        if self.use_ep:
            moe_config = self.moe_config
            all2all_backend = moe_config.moe_parallel_config.all2all_backend
            assert all2all_backend == "allgather_reducescatter", (
                "Fused MoE LoRA with EP currently only supports "
                f"all2all_backend='allgather_reducescatter', got '{all2all_backend}'."
            )
            assert not moe_config.moe_parallel_config.is_sequence_parallel

    def _verify_ep_fs(self, lora_config: LoRAConfig):
        """Reject the EP + fully-sharded-LoRA combination."""
        # EP and fully_sharded LoRA both partition along the same TP group —
        # EP on the expert dim, fully_sharded on the LoRA rank dim — with
        # mutually contradictory assumptions about which rank holds which
        # expert's rank-shard.
        assert not (self.use_ep and lora_config.fully_sharded_loras), (
            "Fused MoE LoRA does not support enable_expert_parallel=True "
            "together with fully_sharded_loras=True. Disable one of them."
        )

    def create_lora_weights(
        self,
        max_loras: int,
        lora_config: LoRAConfig,
        model_config: PretrainedConfig | None = None,
    ) -> None:
        """Initializes lora matrices.

        Allocates the stacked A/B buffers and builds flat per-(adapter,
        expert) views of them in ``self.lora_a_stacked`` /
        ``self.lora_b_stacked``.
        """

        self._verify_ep_fs(lora_config)
        self.max_loras = lora_config.max_loras
        self.fully_sharded = lora_config.fully_sharded_loras

        self.adapter_enabled = torch.tensor(
            [0] * (max_loras + 1), dtype=torch.int, device=self.device
        )

        self._create_lora_a_weights(max_loras, lora_config)
        self._create_lora_b_weights(max_loras, lora_config)
        # They will be used by 'LoRALayerWeights.create_dummy_lora_weights'
        # to create a dummy LoRA weights.
        # TODO Optimize this section
        self.lora_a_stacked = []
        self.lora_b_stacked = []
        for lora_id in range(max_loras):
            for experts_id in range(self.local_num_experts):
                # For gated MoE: gate_proj (w1), down_proj (w2), up_proj (w3)
                # For non-gated MoE: up_proj (w1), down_proj (w2)
                self.lora_a_stacked.append(
                    self.w13_lora_a_stacked[0][lora_id][experts_id]
                )
                self.lora_a_stacked.append(
                    self.w2_lora_a_stacked[0][lora_id][experts_id]
                )

                self.lora_b_stacked.append(
                    self.w13_lora_b_stacked[0][lora_id][experts_id]
                )
                self.lora_b_stacked.append(
                    self.w2_lora_b_stacked[0][lora_id][experts_id]
                )

                # Only add w3 (up_proj) for gated MoE (_w13_slices == 2)
                if self._w13_slices == 2:
                    self.lora_a_stacked.append(
                        self.w13_lora_a_stacked[1][lora_id][experts_id]
                    )
                    self.lora_b_stacked.append(
                        self.w13_lora_b_stacked[1][lora_id][experts_id]
                    )

    def _slice_w13_a(self, w13_lora_a: torch.Tensor) -> torch.Tensor:
        """
        Applies to FusedMoEWithLoRA and FusedMoE3DWithLoRA

        Only slices when fully sharded with ``tp_size > 1``: takes this rank's
        chunk of the rank dim (dim 1), sized like the local LoRA-A buffer.
        """
        if self.tp_size == 1 or not self.fully_sharded:
            return w13_lora_a

        # w13_lora_a shape (num_experts,rank,input_size)
        current_lora_rank = w13_lora_a.shape[1]
        assert current_lora_rank % self.tp_size == 0
        # Based on S-LoRA, we slice W13/W1/W3 A along the rank dim.
        shard_size = self.w13_lora_a_stacked[0].shape[2]
        start_idx = self.tp_rank * shard_size
        end_idx = (self.tp_rank + 1) * shard_size
        return w13_lora_a[:, start_idx:end_idx, :]

    def _slice_w13_b(self, w13_lora_b: torch.Tensor):
        """Take this TP rank's ``intermediate_size_per_partition`` rows (dim 1)."""
        if self.tp_size == 1:
            return w13_lora_b

        # w13_lora_b shape (num_experts,output_size,rank)
        shard_size = self.intermediate_size_per_partition
        start_idx = self.tp_rank * shard_size
        end_idx = (self.tp_rank + 1) * shard_size

        return w13_lora_b[:, start_idx:end_idx, :]

    def _slice_w2_a(self, w2_lora_a: torch.Tensor) -> torch.Tensor:
        """
        Applies to FusedMoEWithLoRA and FusedMoE3DWithLoRA

        Takes this TP rank's ``intermediate_size_per_partition`` input
        columns (last dim) when ``tp_size > 1``.
        """
        if self.tp_size == 1:
            return w2_lora_a
        # w2_lora_a shape (num_experts,rank,input_size)
        shard_size = self.intermediate_size_per_partition
        start_idx = self.tp_rank * shard_size
        end_idx = (self.tp_rank + 1) * shard_size

        return w2_lora_a[:, :, start_idx:end_idx]

    def _slice_w2_b(self, w2_lora_b: torch.Tensor) -> torch.Tensor:
        """
        Applies to FusedMoEWithLoRA and FusedMoE3DWithLoRA

        Only slices when fully sharded with ``tp_size > 1``: takes this rank's
        chunk of the output dim (dim 1), sized like the local LoRA-B buffer.
        """
        if self.tp_size == 1 or not self.fully_sharded:
            return w2_lora_b
        # Based on S-LoRA, we slice W2 B along the hidden_size dim.
        # w2_lora_b shape (num_experts,output_size,rank)
        shard_size = self.w2_lora_b_stacked[0].shape[2]
        start_idx = self.tp_rank * shard_size
        end_idx = (self.tp_rank + 1) * shard_size

        return w2_lora_b[:, start_idx:end_idx, :]

    def reset_lora(self, index: int):
        """Resets the lora weights at index back to 0."""
        for pos in range(self._w13_slices):
            self.w13_lora_a_stacked[pos][index] = 0
            self.w13_lora_b_stacked[pos][index] = 0

        self.w2_lora_a_stacked[0][index] = 0
        self.w2_lora_b_stacked[0][index] = 0
        self.adapter_enabled[index] = 0

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor | list[torch.Tensor],
        lora_b: torch.Tensor | list[torch.Tensor],
    ):
        """Overwrites lora tensors at index.

        ``lora_a`` / ``lora_b`` are three-element ``[w1, w2, w3]`` lists whose
        leading dim must equal the number of local experts. The w3 (up_proj)
        entries are only read for gated MoE (``self._w13_slices == 2``). For
        non-gated MoE their value is ignored and may be a placeholder such as
        ``None``.
        """
        # Make mypy happy
        assert isinstance(lora_a, list)
        assert isinstance(lora_b, list)

        self.reset_lora(index)
        self.adapter_enabled[index] = 1

        num_experts = self.w13_lora_a_stacked[0].shape[1]
        is_gated = self._w13_slices == 2

        w1_lora_a, w2_lora_a, w3_lora_a = lora_a
        w1_lora_b, w2_lora_b, w3_lora_b = lora_b

        # EP slicing is done once at add time in
        # LoRAModelManager._slice_moe_lora_ep, so by here the cached
        # tensors already match the local-expert dim of the stacked buffers.
        assert num_experts == w1_lora_a.shape[0] == w2_lora_a.shape[0]
        if is_gated:
            # w3 is only consumed for gated MoE, so only validate it then.
            assert num_experts == w3_lora_a.shape[0]

        sliced_w1_lora_a = self._slice_w13_a(w1_lora_a)
        sliced_w1_lora_b = self._slice_w13_b(w1_lora_b)

        sliced_w2_lora_a = self._slice_w2_a(w2_lora_a)
        sliced_w2_lora_b = self._slice_w2_b(w2_lora_b)

        self.w13_lora_a_stacked[0][
            index, :, : sliced_w1_lora_a.shape[1], : sliced_w1_lora_a.shape[2]
        ].copy_(sliced_w1_lora_a, non_blocking=True)

        self.w13_lora_b_stacked[0][
            index, :, : sliced_w1_lora_b.shape[1], : sliced_w1_lora_b.shape[2]
        ].copy_(sliced_w1_lora_b, non_blocking=True)

        # Only copy w3 (up_proj) for gated MoE (_w13_slices == 2)
        if is_gated:
            sliced_w3_lora_a = self._slice_w13_a(w3_lora_a)
            sliced_w3_lora_b = self._slice_w13_b(w3_lora_b)

            self.w13_lora_a_stacked[1][
                index, :, : sliced_w3_lora_a.shape[1], : sliced_w3_lora_a.shape[2]
            ].copy_(sliced_w3_lora_a, non_blocking=True)

            self.w13_lora_b_stacked[1][
                index, :, : sliced_w3_lora_b.shape[1], : sliced_w3_lora_b.shape[2]
            ].copy_(sliced_w3_lora_b, non_blocking=True)

        self.w2_lora_a_stacked[0][
            index, :, : sliced_w2_lora_a.shape[1], : sliced_w2_lora_a.shape[2]
        ].copy_(sliced_w2_lora_a, non_blocking=True)

        self.w2_lora_b_stacked[0][
            index, :, : sliced_w2_lora_b.shape[1], : sliced_w2_lora_b.shape[2]
        ].copy_(sliced_w2_lora_b, non_blocking=True)

    def set_mapping(self, punica_wrapper):
        """Bind the punica wrapper and push a fresh MoELoRAContext to the kernel.

        The context goes to the fused-experts stage, and also to the
        prepare/finalize stage when that stage exposes ``set_lora_context``.
        """
        super().set_mapping(punica_wrapper)
        lora_context = self._build_lora_context()
        self._moe_kernel.fused_experts.set_lora_context(lora_context)
        prepare_finalize = self._moe_kernel.prepare_finalize
        if hasattr(prepare_finalize, "set_lora_context"):
            prepare_finalize.set_lora_context(lora_context)

    def forward(self, *args, **kwargs):
        """Delegate to the wrapped base layer.

        The LoRA context is installed on the kernel by ``set_mapping``.
        """
        return self.base_layer.forward(*args, **kwargs)

    @property
    def quant_method(self):
        """The base layer's current quant method."""
        return self.base_layer._quant_method

    @property
    def runner(self) -> MoERunner:
        """The wrapped ``MoERunner``."""
        return self.base_layer

    @property
    def is_internal_router(self) -> bool:
        """Forwarded from the base layer."""
        return self.base_layer.is_internal_router

    @classmethod
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: list,
        model_config: PretrainedConfig | None = None,
    ) -> bool:
        """Returns True if the layer can be replaced by this LoRA layer."""

        # source_layer is MoERunner
        moe_cls = maybe_get_oot_by_class(MoERunner)
        return isinstance(source_layer, moe_cls) and len(packed_modules_list) == 2

_slice_w13_a(w13_lora_a)

Applies to FusedMoEWithLoRA and FusedMoE3DWithLoRA

Source code in vllm/lora/layers/fused_moe.py
def _slice_w13_a(self, w13_lora_a: torch.Tensor) -> torch.Tensor:
    """
    Applies to FusedMoEWithLoRA and FusedMoE3DWithLoRA

    Return the part of a w13 LoRA-A tensor owned by this TP rank. Input is
    ``(num_experts, rank, input_size)``. When LoRA is fully sharded across
    more than one rank, the rank dim (dim 1) is cut (S-LoRA style) into
    chunks the size of the local buffer's rank dim. Otherwise the tensor
    is returned untouched.
    """
    if self.tp_size != 1 and self.fully_sharded:
        assert w13_lora_a.shape[1] % self.tp_size == 0
        chunk = self.w13_lora_a_stacked[0].shape[2]
        first_row = self.tp_rank * chunk
        return w13_lora_a[:, first_row : first_row + chunk, :]
    return w13_lora_a

_slice_w2_a(w2_lora_a)

Applies to FusedMoEWithLoRA and FusedMoE3DWithLoRA

Source code in vllm/lora/layers/fused_moe.py
def _slice_w2_a(self, w2_lora_a: torch.Tensor) -> torch.Tensor:
    """
    Applies to FusedMoEWithLoRA and FusedMoE3DWithLoRA

    Return this TP rank's input columns of a w2 LoRA-A tensor shaped
    ``(num_experts, rank, input_size)``. With more than one TP rank, the last
    dim is cut into ``intermediate_size_per_partition``-wide pieces. With a
    single rank the tensor is returned as-is.
    """
    if self.tp_size != 1:
        width = self.intermediate_size_per_partition
        first_col = self.tp_rank * width
        w2_lora_a = w2_lora_a[:, :, first_col : first_col + width]
    return w2_lora_a

_slice_w2_b(w2_lora_b)

Applies to FusedMoEWithLoRA and FusedMoE3DWithLoRA

Source code in vllm/lora/layers/fused_moe.py
def _slice_w2_b(self, w2_lora_b: torch.Tensor) -> torch.Tensor:
    """
    Applies to FusedMoEWithLoRA and FusedMoE3DWithLoRA

    Return this TP rank's output rows of a w2 LoRA-B tensor shaped
    ``(num_experts, output_size, rank)``. Only when LoRA is fully sharded
    over more than one rank is the hidden-size dim (dim 1) cut (S-LoRA
    style), into chunks matching the local buffer. Otherwise the input is
    returned unchanged.
    """
    if self.tp_size != 1 and self.fully_sharded:
        chunk = self.w2_lora_b_stacked[0].shape[2]
        first_row = self.tp_rank * chunk
        return w2_lora_b[:, first_row : first_row + chunk, :]
    return w2_lora_b

can_replace_layer(source_layer, lora_config, packed_modules_list, model_config=None) classmethod

Returns True if the layer can be replaced by this LoRA layer.

Source code in vllm/lora/layers/fused_moe.py
@classmethod
def can_replace_layer(
    cls,
    source_layer: nn.Module,
    lora_config: LoRAConfig,
    packed_modules_list: list,
    model_config: PretrainedConfig | None = None,
) -> bool:
    """Returns True if the layer can be replaced by this LoRA layer.

    The layer must be an instance of ``maybe_get_oot_by_class(MoERunner)``
    and come with exactly two packed modules.
    """
    runner_cls = maybe_get_oot_by_class(MoERunner)
    if not isinstance(source_layer, runner_cls):
        return False
    return len(packed_modules_list) == 2

create_lora_weights(max_loras, lora_config, model_config=None)

Initializes lora matrices.

Source code in vllm/lora/layers/fused_moe.py
def create_lora_weights(
    self,
    max_loras: int,
    lora_config: LoRAConfig,
    model_config: PretrainedConfig | None = None,
) -> None:
    """Initializes lora matrices.

    Allocates the stacked A/B buffers through the ``_create_lora_*_weights``
    helpers. It then records flat per-(adapter, expert) views of them in
    ``self.lora_a_stacked`` / ``self.lora_b_stacked``: per expert, w1 then w2,
    plus w3 for gated MoE.
    """
    self._verify_ep_fs(lora_config)
    self.max_loras = lora_config.max_loras
    self.fully_sharded = lora_config.fully_sharded_loras

    self.adapter_enabled = torch.tensor(
        [0] * (max_loras + 1), dtype=torch.int, device=self.device
    )

    self._create_lora_a_weights(max_loras, lora_config)
    self._create_lora_b_weights(max_loras, lora_config)

    # Flat views read by 'LoRALayerWeights.create_dummy_lora_weights' when
    # it builds dummy LoRA weights.
    # TODO Optimize this section
    gated = self._w13_slices == 2
    flat_a = []
    flat_b = []
    for lora_id in range(max_loras):
        for expert in range(self.local_num_experts):
            # gated:     gate_proj (w1), down_proj (w2), up_proj (w3)
            # non-gated: up_proj (w1), down_proj (w2)
            flat_a.append(self.w13_lora_a_stacked[0][lora_id][expert])
            flat_a.append(self.w2_lora_a_stacked[0][lora_id][expert])
            flat_b.append(self.w13_lora_b_stacked[0][lora_id][expert])
            flat_b.append(self.w2_lora_b_stacked[0][lora_id][expert])
            if gated:
                flat_a.append(self.w13_lora_a_stacked[1][lora_id][expert])
                flat_b.append(self.w13_lora_b_stacked[1][lora_id][expert])
    self.lora_a_stacked = flat_a
    self.lora_b_stacked = flat_b

reset_lora(index)

Resets the lora weights at index back to 0.

Source code in vllm/lora/layers/fused_moe.py
def reset_lora(self, index: int):
    """Resets the lora weights at index back to 0.

    Zeros slot ``index`` in every active w13 slice (A then B), then in the
    w2 A and B buffers, and finally clears the slot's enabled flag.
    """
    for slot in range(self._w13_slices):
        for stacked in (self.w13_lora_a_stacked, self.w13_lora_b_stacked):
            stacked[slot][index] = 0

    for stacked in (self.w2_lora_a_stacked, self.w2_lora_b_stacked):
        stacked[0][index] = 0
    self.adapter_enabled[index] = 0

set_lora(index, lora_a, lora_b)

Overwrites lora tensors at index.

Source code in vllm/lora/layers/fused_moe.py
def set_lora(
    self,
    index: int,
    lora_a: torch.Tensor | list[torch.Tensor],
    lora_b: torch.Tensor | list[torch.Tensor],
):
    """Overwrites lora tensors at index."""
    # Make mypy happy
    assert isinstance(lora_a, list)
    assert isinstance(lora_b, list)

    self.reset_lora(index)
    self.adapter_enabled[index] = 1

    num_experts = self.w13_lora_a_stacked[0].shape[1]

    w1_lora_a, w2_lora_a, w3_lora_a = lora_a
    w1_lora_b, w2_lora_b, w3_lora_b = lora_b

    # EP slicing is done once at add time in
    # LoRAModelManager._slice_moe_lora_ep, so by here the cached
    # tensors already match the local-expert dim of the stacked buffers.
    assert (
        num_experts
        == w1_lora_a.shape[0]
        == w2_lora_a.shape[0]
        == w3_lora_a.shape[0]
    )

    slliced_w1_lora_a = self._slice_w13_a(w1_lora_a)
    slliced_w1_lora_b = self._slice_w13_b(w1_lora_b)

    sliced_w2_lora_a = self._slice_w2_a(w2_lora_a)
    sliced_w2_lora_b = self._slice_w2_b(w2_lora_b)

    self.w13_lora_a_stacked[0][
        index, :, : slliced_w1_lora_a.shape[1], : slliced_w1_lora_a.shape[2]
    ].copy_(slliced_w1_lora_a, non_blocking=True)

    self.w13_lora_b_stacked[0][
        index, :, : slliced_w1_lora_b.shape[1], : slliced_w1_lora_b.shape[2]
    ].copy_(slliced_w1_lora_b, non_blocking=True)

    # Only copy w3 (up_proj) for gated MoE (_w13_slices == 2)
    if self._w13_slices == 2:
        slliced_w3_lora_a = self._slice_w13_a(w3_lora_a)
        slliced_w3_lora_b = self._slice_w13_b(w3_lora_b)

        self.w13_lora_a_stacked[1][
            index, :, : slliced_w3_lora_a.shape[1], : slliced_w3_lora_a.shape[2]
        ].copy_(slliced_w3_lora_a, non_blocking=True)

        self.w13_lora_b_stacked[1][
            index, :, : slliced_w3_lora_b.shape[1], : slliced_w3_lora_b.shape[2]
        ].copy_(slliced_w3_lora_b, non_blocking=True)

    self.w2_lora_a_stacked[0][
        index, :, : sliced_w2_lora_a.shape[1], : sliced_w2_lora_a.shape[2]
    ].copy_(sliced_w2_lora_a, non_blocking=True)

    self.w2_lora_b_stacked[0][
        index, :, : sliced_w2_lora_b.shape[1], : sliced_w2_lora_b.shape[2]
    ].copy_(sliced_w2_lora_b, non_blocking=True)