Skip to content

vllm.v1.sample.logits_processor

Modules:

Classes:

AdapterLogitsProcessor

Bases: LogitsProcessor

Wrapper for per-request logits processors

To wrap a specific per-request logits processor:

  • Subclass AdapterLogitsProcessor
  • Implement the self.is_argmax_invariant() base-class method
  • Implement self.new_req_logits_processor(params)

self.__init__(vllm_config, device, is_pin_memory) does not need to be overridden in general. However, to implement custom constructor behavior (especially any logic that operates on or stores vllm_config, device, or is_pin_memory), self.__init__(vllm_config, device, is_pin_memory) must be overridden, and the override must call super().__init__(vllm_config, device, is_pin_memory).

Methods:

Source code in vllm/v1/sample/logits_processor/__init__.py
class AdapterLogitsProcessor(LogitsProcessor):
    """Wrapper for per-request logits processors

    To wrap a specific per-request logits processor:
    * Subclass `AdapterLogitsProcessor`
    * Implement the `self.is_argmax_invariant()` base-class method
    * Implement `self.new_req_logits_processor(params)`

    `self.__init__(vllm_config, device, is_pin_memory)` does not need to be
    overridden in general. However, to implement custom constructor behavior
    (especially any logic which operates on or stores `vllm_config`, `device`,
    or `is_pin_memory`), `self.__init__(vllm_config, device, is_pin_memory)`
    must be overridden and the override must call
    `super().__init__(vllm_config, device, is_pin_memory)`.

    A per-request logits processor is a callable taking either
    `(output_ids, logits_row)` or `(prompt_ids, output_ids, logits_row)`,
    where `logits_row` is one row of the batch logits tensor. It returns the
    row, either modified in place or as a new tensor.
    """

    def __init__(
        self, vllm_config: "VllmConfig", device: torch.device, is_pin_memory: bool
    ):
        """Subclass must invoke
        `super().__init__(vllm_config, device, is_pin_memory)`.

        Subclass constructors may find it useful to use the `vllm_config`,
        `device` and `is_pin_memory` arguments. However, regardless of whether
        these arguments are used, the vLLM logits processor interface requires
        all three arguments to be present.
        """

        # Map req index -> logits processor state
        #
        # State representation is a partial[Tensor] comprising a request-level
        # logits processor with the output token ids argument and (if required)
        # the prompt token ids argument pre-populated
        #
        # Note that the partial carries a *reference* to output token ids, and
        # will thus always operate on the list as it is currently, not as it
        # was when the partial was created.
        self.req_info: dict[int, partial[torch.Tensor]] = {}

    @abstractmethod
    def new_req_logits_processor(
        self,
        params: SamplingParams,
    ) -> RequestLogitsProcessor | None:
        """Consume request info; return a per-request logits processor.

        Return None if logits processor does not need to be applied to request

        Args:
          params: request sampling params

        Returns:
          None if logits processor should not be applied to request; otherwise
          returns a `RequestLogitsProcessor` instance

        """
        raise NotImplementedError

    def _new_state(
        self,
        params: SamplingParams,
        prompt_ids: list[int] | None,
        output_ids: list[int],
    ) -> partial[torch.Tensor] | None:
        """Return state representation for new request

        Returns None if logits processor is not applicable to request

        Args:
          params: request sampling params
          prompt_ids: request prompt token ids; required when the per-request
            logits processor takes three parameters
          output_ids: decoded tokens so far for this request

        Returns:
          logits processor partial[Tensor] or None

        Raises:
          ValueError: if the per-request logits processor takes three
            parameters (prompt ids, output ids, logits) but `prompt_ids`
            is None
        """
        req_lp = self.new_req_logits_processor(params)
        # The contract signals "not applicable" with None only. Test identity
        # rather than truthiness so a valid callable whose __bool__/__len__
        # reports False is not silently dropped.
        if req_lp is None:
            return None
        # A 3-parameter processor is called as (prompt_ids, output_ids, logits),
        # otherwise as (output_ids, logits); bind everything except the logits.
        if len(inspect.signature(req_lp).parameters) == 3:
            if prompt_ids is None:
                raise ValueError(
                    "Prompt token ids are required for this "
                    "logits processor but were not provided."
                )
            return partial(req_lp, prompt_ids, output_ids)
        return partial(req_lp, output_ids)

    def update_state(self, batch_update: BatchUpdate | None):
        """Sync `self.req_info` with persistent-batch changes.

        Delegates to `process_dict_updates`, passing `self._new_state` to build
        the state of requests entering the batch (presumably called once per
        added request; see `process_dict_updates`).
        """
        process_dict_updates(
            self.req_info,
            batch_update,
            self._new_state,
        )

    def apply(self, logits: torch.Tensor) -> torch.Tensor:
        """Apply each request's logits processor to its row of `logits`.

        Rows without a registered processor are left untouched. The tensor is
        updated in place and also returned.
        """
        for req_idx, req_lp in self.req_info.items():
            req_logits = logits[req_idx]
            new_logits = req_lp(req_logits)
            if new_logits is not req_logits:
                # The processor returned a new tensor instead of modifying the
                # row view in place, so copy the result back into the batch.
                logits[req_idx] = new_logits
        return logits

__init__(vllm_config, device, is_pin_memory)

Subclass must invoke super().__init__(vllm_config, device, is_pin_memory).

Subclass constructors may find it useful to use the vllm_config, device and is_pin_memory arguments. However, regardless of whether these arguments are used, the vLLM logits processor interface requires all three arguments to be present.

Source code in vllm/v1/sample/logits_processor/__init__.py
def __init__(
    self, vllm_config: "VllmConfig", device: torch.device, is_pin_memory: bool
):
    """Initialize the adapter's per-request state.

    Subclasses that override the constructor must call
    `super().__init__(vllm_config, device, is_pin_memory)`. All three
    arguments belong to the vLLM logits processor interface and must be
    accepted even when a subclass has no use for them. This base
    implementation does not use any of them.
    """
    # req index -> partial[Tensor]: the request-level logits processor with
    # the output token ids (and, when its signature asks for them, the prompt
    # token ids) already bound.
    #
    # The partial holds a *reference* to the output token ids list, so it
    # always sees the tokens generated so far, not a snapshot taken when the
    # partial was built.
    req_info: dict[int, partial[torch.Tensor]] = dict()
    self.req_info = req_info

_new_state(params, prompt_ids, output_ids)

Return state representation for new request

Returns None if logits processor is not applicable to request

Parameters:

  • params

    (SamplingParams) –

    request sampling params

  • prompt_ids

    (list[int] | None) –

    request prompt token ids

  • output_ids

    (list[int]) –

    decoded tokens so far for this request

Returns:

  • partial[Tensor] | None

    logits processor partial[Tensor] or None

Source code in vllm/v1/sample/logits_processor/__init__.py
def _new_state(
    self,
    params: SamplingParams,
    prompt_ids: list[int] | None,
    output_ids: list[int],
) -> partial[torch.Tensor] | None:
    """Return state representation for new request

    Returns None if logits processor is not applicable to request

    Args:
      params: request sampling params
      prompt_ids: request prompt token ids; required when the per-request
        logits processor takes three parameters
      output_ids: decoded tokens so far for this request

    Returns:
      logits processor partial[Tensor] or None

    Raises:
      ValueError: if the per-request logits processor takes three
        parameters (prompt ids, output ids, logits) but `prompt_ids` is None
    """
    req_lp = self.new_req_logits_processor(params)
    # `new_req_logits_processor` signals "not applicable" with None only.
    # Test identity rather than truthiness so a valid callable whose
    # __bool__/__len__ reports False is not silently dropped.
    if req_lp is None:
        return None
    # A 3-parameter processor is called as (prompt_ids, output_ids, logits),
    # otherwise as (output_ids, logits); bind everything except the logits.
    if len(inspect.signature(req_lp).parameters) == 3:
        if prompt_ids is None:
            raise ValueError(
                "Prompt token ids are required for this "
                "logits processor but were not provided."
            )
        return partial(req_lp, prompt_ids, output_ids)
    return partial(req_lp, output_ids)

new_req_logits_processor(params) abstractmethod

Consume request info; return a per-request logits processor.

Return None if logits processor does not need to be applied to request

Parameters:

  • params

    (SamplingParams) –

    request sampling params

Returns:

  • RequestLogitsProcessor | None

    None if logits processor should not be applied to request; otherwise
    returns a RequestLogitsProcessor instance

Source code in vllm/v1/sample/logits_processor/__init__.py
@abstractmethod
def new_req_logits_processor(
    self, params: SamplingParams
) -> RequestLogitsProcessor | None:
    """Build the per-request logits processor for a newly added request.

    Subclasses must override this hook.

    Args:
      params: sampling params of the new request

    Returns:
      A `RequestLogitsProcessor` instance to apply to this request, or
      None when the logits processor does not need to be applied to it.
    """
    raise NotImplementedError()

BatchUpdate dataclass

Persistent batch state change info for logitsprocs

Source code in vllm/v1/sample/logits_processor/interface.py
@dataclass(frozen=True)
class BatchUpdate:
    """Persistent batch state change info for logitsprocs.

    Describes, for one step, which requests were removed from, added to,
    and moved within the persistent batch. `frozen=True` prevents
    reassigning the fields. The sequences they hold are not copied, so
    they are not themselves made immutable.
    """

    batch_size: int  # Current num reqs in batch

    # Metadata for requests added to, removed from, and moved
    # within the persistent batch.
    #
    # Key assumption: the `output_tok_ids` list (which is an element of each
    # tuple in `added`) is a reference to the request's running output tokens
    # list; via this reference, the logits processors always see the latest
    # list of generated output tokens.
    # NOTE(review): `added` entries are unpacked as 4-tuples elsewhere in this
    # module; presumably (index, params, prompt_tok_ids, output_tok_ids) --
    # confirm against the AddedRequest definition.
    #
    # NOTE:
    # * Added or moved requests may replace existing requests with the same
    #   index.
    # * Operations should be processed in the following order:
    #   - removed, added, moved
    #   NOTE(review): MinPLogitsProcessor.update_state in this module handles
    #   `added` before `removed`. That is only equivalent if an index never
    #   appears in both `added` and `removed` within one update -- confirm.
    removed: Sequence[RemovedRequest]
    added: Sequence[AddedRequest]
    moved: Sequence[MovedRequest]

BatchUpdateBuilder

Helps track persistent batch state changes and build a batch update data structure for logitsprocs.

Assumptions:

  • All information about requests removed from the persistent batch during a step is aggregated in self._removed through calls to self.removed_append() at the beginning of a step. This must happen before the first time that self.removed, self.pop_removed() or self.peek_removed() are invoked in a given step.
  • After the first time that self.removed, self.pop_removed() or self.peek_removed() are read in a step, no new removals are registered using self.removed_append().
  • Elements of self._removed are never directly modified, added or removed (i.e. modification is only via self.removed_append() and self.pop_removed()).

Guarantees under the above assumptions:

  • self.removed is always sorted in descending order.
  • self.pop_removed() and self.peek_removed() both return the lowest removed request index in the current step.

Methods:

  • get_and_reset

    Generate a logitsprocs batch update data structure and reset internal batch update builder state.

  • peek_removed

    Return lowest removed request index

  • pop_removed

    Pop lowest removed request index

  • removed_append

    Register the removal of a request from the persistent batch.

  • reset

    Clear all pending batch changes; return True if there were any changes to the batch.

Attributes:

  • removed (list[RemovedRequest]) –

    Removed request indices sorted in descending order.

Source code in vllm/v1/sample/logits_processor/state.py
class BatchUpdateBuilder:
    """Accumulate persistent-batch changes for one step and package them
    into a `BatchUpdate` for logitsprocs.

    Assumptions:
    * Every request removed from the persistent batch during a step is
      registered via self.removed_append() at the start of the step, before
      self.removed, self.pop_removed() or self.peek_removed() is first used
      in that step
    * Once self.removed, self.pop_removed() or self.peek_removed() has been
      used in a step, self.removed_append() is not called again in that step
    * self._removed is only mutated through self.removed_append() and
      self.pop_removed()
    Guarantees under the above assumptions:
    * self.removed is sorted in descending order
    * self.pop_removed() and self.peek_removed() return the lowest removed
      request index of the current step
    """

    _removed: list[RemovedRequest]
    _is_removed_sorted: bool
    added: list[AddedRequest]
    moved: list[MovedRequest]

    def __init__(
        self,
        removed: list[RemovedRequest] | None = None,
        added: list[AddedRequest] | None = None,
        moved: list[MovedRequest] | None = None,
    ) -> None:
        # A missing or empty argument is replaced by a fresh list.
        self._removed = removed if removed else []
        self.added = added if added else []
        self.moved = moved if moved else []
        self._is_removed_sorted = False

        # Records whether the batch changed at all. The pooling case does not
        # populate the `added` list, so this flag carries the signal there.
        self.batch_changed = False

    def _ensure_removed_sorted(self) -> None:
        """Sort pending removed indices from highest to lowest.

        Runs at most once per step; later calls are no-ops until the
        sorted flag is cleared by reset() or get_and_reset().
        """
        if self._is_removed_sorted:
            return
        self._removed.sort(reverse=True)
        self._is_removed_sorted = True

    @property
    def removed(self) -> list[RemovedRequest]:
        """Pending removed request indices, highest first."""
        self._ensure_removed_sorted()
        return self._removed

    def removed_append(self, index: int) -> None:
        """Record that the request at `index` left the persistent batch.

        Not allowed once self.removed, self.pop_removed() or
        self.peek_removed() has sorted the pending removals this step.

        Args:
          index: request index

        Raises:
          RuntimeError: if the removed list has already been sorted.
        """
        if not self._is_removed_sorted:
            self._removed.append(index)
            self.batch_changed = True
            return
        raise RuntimeError(
            "Cannot register new removed request after self.removed has been read."
        )

    def has_removed(self) -> bool:
        """Whether any removals are pending."""
        return len(self._removed) > 0

    def peek_removed(self) -> int | None:
        """Lowest pending removed index, without consuming it (None if none)."""
        if not self.has_removed():
            return None
        # Sorted descending, so the smallest index is the last element.
        self._ensure_removed_sorted()
        return self._removed[-1]

    def pop_removed(self) -> int | None:
        """Consume and return the lowest pending removed index (None if none)."""
        if not self.has_removed():
            return None
        self._ensure_removed_sorted()
        return self._removed.pop()

    def reset(self) -> bool:
        """Drop every pending change and clear the change flag.

        Returns:
          The value `self.batch_changed` held before it was cleared.
        """
        had_changes = self.batch_changed
        self.batch_changed = False
        self._is_removed_sorted = False
        for pending in (self._removed, self.added, self.moved):
            pending.clear()
        return had_changes

    def get_and_reset(self, batch_size: int) -> BatchUpdate | None:
        """Package the pending changes as a `BatchUpdate` and start afresh.

        Args:
          batch_size: current persistent batch size

        Returns:
          Frozen logitsprocs batch update instance; `None` if no updates
        """
        self._is_removed_sorted = False
        self.batch_changed = False
        if not (self._removed or self.moved or self.added):
            return None
        update = BatchUpdate(
            batch_size=batch_size,
            removed=self._removed,
            moved=self.moved,
            added=self.added,
        )
        # Hand the current lists to the BatchUpdate and start with new ones,
        # so later mutations here cannot leak into the returned update.
        self._removed, self.moved, self.added = [], [], []
        return update

removed property

Removed request indices sorted in descending order

_ensure_removed_sorted()

Sort removed request indices in descending order. Idempotent after first call in a given step, until reset.

Source code in vllm/v1/sample/logits_processor/state.py
def _ensure_removed_sorted(self) -> None:
    """Sort pending removed indices from highest to lowest.

    Runs at most once per step; repeated calls do nothing until the
    `_is_removed_sorted` flag is cleared again.
    """
    if self._is_removed_sorted:
        return
    self._removed.sort(reverse=True)
    self._is_removed_sorted = True

get_and_reset(batch_size)

Generate a logitsprocs batch update data structure and reset internal batch update builder state.

Parameters:

  • batch_size

    (int) –

    current persistent batch size

Returns:

  • BatchUpdate | None

    Frozen logitsprocs batch update instance; None if no updates

Source code in vllm/v1/sample/logits_processor/state.py
def get_and_reset(self, batch_size: int) -> BatchUpdate | None:
    """Package the pending changes as a `BatchUpdate` and start afresh.

    Args:
      batch_size: current persistent batch size

    Returns:
      Frozen logitsprocs batch update instance; `None` if no updates
    """
    self._is_removed_sorted = False
    self.batch_changed = False
    if not (self._removed or self.moved or self.added):
        return None
    update = BatchUpdate(
        batch_size=batch_size,
        removed=self._removed,
        moved=self.moved,
        added=self.added,
    )
    # The BatchUpdate keeps the current lists; start with new, empty ones.
    self._removed, self.moved, self.added = [], [], []
    return update

peek_removed()

Return lowest removed request index

Source code in vllm/v1/sample/logits_processor/state.py
def peek_removed(self) -> int | None:
    """Return lowest removed request index"""
    if self.has_removed():
        self._ensure_removed_sorted()
        return self._removed[-1]
    return None

pop_removed()

Pop lowest removed request index

Source code in vllm/v1/sample/logits_processor/state.py
def pop_removed(self) -> int | None:
    """Pop lowest removed request index"""
    if self.has_removed():
        self._ensure_removed_sorted()
        return self._removed.pop()
    return None

removed_append(index)

Register the removal of a request from the persistent batch.

Must not be called after the first time self.removed, self.pop_removed() or self.peek_removed() are invoked.

Parameters:

  • index

    (int) –

    request index

Source code in vllm/v1/sample/logits_processor/state.py
def removed_append(self, index: int) -> None:
    """Record that the request at `index` left the persistent batch.

    Not allowed once the pending removals have been sorted this step, which
    happens on the first read of self.removed, self.pop_removed() or
    self.peek_removed().

    Args:
      index: request index

    Raises:
      RuntimeError: if the removed list has already been sorted.
    """
    if not self._is_removed_sorted:
        self._removed.append(index)
        self.batch_changed = True
        return
    raise RuntimeError(
        "Cannot register new removed request after self.removed has been read."
    )

reset()

Returns True if there were any changes to the batch.

Source code in vllm/v1/sample/logits_processor/state.py
def reset(self) -> bool:
    """Drop every pending change and clear the change flag.

    The pending lists are emptied in place (the list objects are kept).

    Returns:
      The value `self.batch_changed` held before it was cleared.
    """
    had_changes = self.batch_changed
    self.batch_changed = False
    self._is_removed_sorted = False
    for pending in (self._removed, self.added, self.moved):
        pending.clear()
    return had_changes

LogitBiasLogitsProcessor

Bases: LogitsProcessor

Methods:

Source code in vllm/v1/sample/logits_processor/builtin.py
class LogitBiasLogitsProcessor(LogitsProcessor):
    """Add each request's `SamplingParams.logit_bias` values to the
    corresponding token logits of its row, in place."""

    def __init__(self, _, device: torch.device, is_pin_memory: bool):
        self.device = device
        self.pin_memory = is_pin_memory
        # req index -> {token id -> bias}
        self.biases: dict[int, dict[int, float]] = {}

        # Flattened biases plus the (row indices, token indices) pair used to
        # address them in the logits tensor; rebuilt by update_state().
        self.bias_tensor: torch.Tensor = torch.tensor(())
        self.logits_slice = tuple(
            self._device_tensor([], torch.int32) for _slot in range(2)
        )

    def is_argmax_invariant(self) -> bool:
        """Biasing individual tokens can reorder the logits and therefore
        change the greedy (argmax) choice."""
        return False

    def update_state(self, batch_update: BatchUpdate | None):
        changed = process_dict_updates(
            self.biases, batch_update, lambda params, _, __: params.logit_bias or None
        )
        if not changed:
            return

        # Flatten {req: {tok: bias}} into parallel row/token/bias lists,
        # preserving per-request iteration order.
        entries = [
            (req, tok, bias)
            for req, per_req in self.biases.items()
            for tok, bias in per_req.items()
        ]
        row_idx = [req for req, _tok, _bias in entries]
        col_idx = [tok for _req, tok, _bias in entries]
        bias_vals = [bias for _req, _tok, bias in entries]

        self.bias_tensor = self._device_tensor(bias_vals, torch.float32)
        self.logits_slice = (
            self._device_tensor(row_idx, torch.int32),
            self._device_tensor(col_idx, torch.int32),
        )

    def _device_tensor(self, data: list, dtype: torch.dtype) -> torch.Tensor:
        host = torch.tensor(data, device="cpu", dtype=dtype, pin_memory=self.pin_memory)
        return host.to(device=self.device, non_blocking=True)

    def apply(self, logits: torch.Tensor) -> torch.Tensor:
        if not self.biases:
            return logits
        logits[self.logits_slice] += self.bias_tensor
        return logits

is_argmax_invariant()

Logit bias can rebalance token probabilities and change the outcome of argmax in greedy sampling.

Source code in vllm/v1/sample/logits_processor/builtin.py
def is_argmax_invariant(self) -> bool:
    """Return False: adding a bias to individual token logits can change
    which token has the largest logit, so greedy sampling is affected."""
    argmax_unaffected = False
    return argmax_unaffected

LogitsProcessor

Bases: ABC

Methods:

  • apply

    Apply LogitsProcessor to batch logits tensor.

  • is_argmax_invariant

    True if logits processor has no impact on the argmax computation in greedy sampling.

  • update_state

    Called when there are new output tokens, prior to each forward pass.

  • validate_params

    Validate sampling params for this logits processor.

Source code in vllm/v1/sample/logits_processor/interface.py
class LogitsProcessor(ABC):
    """Abstract interface for batch-level logits processors.

    Concrete processors are constructed with
    `(vllm_config, device, is_pin_memory)`, receive batch changes through
    `update_state()`, and transform the batch logits in `apply()`.
    """

    @abstractmethod
    def __init__(
        self, vllm_config: "VllmConfig", device: torch.device, is_pin_memory: bool
    ) -> None:
        raise NotImplementedError()

    @classmethod
    def validate_params(cls, sampling_params: SamplingParams):
        """Validate sampling params for this logits processor.

        Implementations raise ValueError for invalid params. This default
        accepts everything.
        """
        return

    @abstractmethod
    def update_state(
        self,
        batch_update: "BatchUpdate | None",
    ) -> None:
        """Hook invoked before every forward pass, when new output tokens
        are available.

        Args:
            batch_update: None unless the batch composition changed since
                the previous call.
        """
        raise NotImplementedError()

    @abstractmethod
    def apply(self, logits: torch.Tensor) -> torch.Tensor:
        """Transform the batch logits tensor.

        Implementations may modify `logits` in place, but must always
        return the resulting tensor.
        """
        raise NotImplementedError()

    @abstractmethod
    def is_argmax_invariant(self) -> bool:
        """Whether this processor can never change the argmax taken by
        greedy sampling.

        NOTE: the answer may differ between instances of the same
        subclass, depending on how the subclass implements it.
        """
        raise NotImplementedError()

apply(logits) abstractmethod

Apply LogitsProcessor to batch logits tensor.

The updated tensor must be returned but may be modified in-place.

Source code in vllm/v1/sample/logits_processor/interface.py
@abstractmethod
def apply(self, logits: torch.Tensor) -> torch.Tensor:
    """Transform the batch logits tensor.

    Implementations may modify `logits` in place, but must always return
    the resulting tensor.
    """
    raise NotImplementedError()

is_argmax_invariant() abstractmethod

True if logits processor has no impact on the argmax computation in greedy sampling. NOTE: may or may not have the same value for all instances of a given LogitsProcessor subclass, depending on subclass implementation.

Source code in vllm/v1/sample/logits_processor/interface.py
@abstractmethod
def is_argmax_invariant(self) -> bool:
    """Whether this processor can never change the argmax taken by greedy
    sampling.

    NOTE: the answer may differ between instances of the same
    LogitsProcessor subclass, depending on how the subclass implements it.
    """
    raise NotImplementedError()

update_state(batch_update) abstractmethod

Called when there are new output tokens, prior to each forward pass.

Parameters:

  • batch_update

    (BatchUpdate | None) –

    Non-None iff there have been changes to the batch makeup.

Source code in vllm/v1/sample/logits_processor/interface.py
@abstractmethod
def update_state(self, batch_update: "BatchUpdate | None") -> None:
    """Hook invoked before every forward pass, when new output tokens are
    available.

    Args:
        batch_update: None unless the batch composition changed since the
            previous call.
    """
    raise NotImplementedError()

validate_params(sampling_params) classmethod

Validate sampling params for this logits processor.

Raise ValueError for invalid ones.

Source code in vllm/v1/sample/logits_processor/interface.py
@classmethod
def validate_params(cls, sampling_params: SamplingParams):
    """Validate sampling params for this logits processor.

    Implementations raise ValueError for invalid params. This default
    accepts everything.
    """
    return

LogitsProcessors

Encapsulates initialized logitsproc objects.

Attributes:

  • all (Iterator[LogitsProcessor]) –

    Iterator over all logits processors.

Source code in vllm/v1/sample/logits_processor/state.py
class LogitsProcessors:
    """Encapsulates initialized logitsproc objects."""

    def __init__(self, logitsprocs: Iterable["LogitsProcessor"] | None = None) -> None:
        self.argmax_invariant: list[LogitsProcessor] = []
        self.non_argmax_invariant: list[LogitsProcessor] = []
        if logitsprocs:
            for logitproc in logitsprocs:
                (
                    self.argmax_invariant
                    if logitproc.is_argmax_invariant()
                    else self.non_argmax_invariant
                ).append(logitproc)

    @property
    def all(self) -> Iterator["LogitsProcessor"]:
        """Iterator over all logits processors."""
        return chain(self.argmax_invariant, self.non_argmax_invariant)

all property

Iterator over all logits processors.

MinPLogitsProcessor

Bases: LogitsProcessor

Methods:

Source code in vllm/v1/sample/logits_processor/builtin.py
class MinPLogitsProcessor(LogitsProcessor):
    """Min-p filtering.

    For every row whose request has a non-zero `min_p`, sets to -inf each
    token whose probability is below `min_p * (max probability in that row)`.

    Per-slot `min_p` values live in a CPU tensor, with a numpy view used for
    element-wise bookkeeping. On a non-CPU device they are mirrored into a
    pre-allocated device tensor.
    """

    def __init__(
        self, vllm_config: "VllmConfig", device: torch.device, is_pin_memory: bool
    ):
        max_num_reqs = vllm_config.scheduler_config.max_num_seqs
        # Number of batch slots whose min_p is currently non-zero; apply() is
        # a no-op while this is 0.
        self.min_p_count: int = 0

        # Host-side min_p value per possible request slot (optionally pinned).
        self.min_p_cpu_tensor = torch.zeros(
            (max_num_reqs,), dtype=torch.float32, device="cpu", pin_memory=is_pin_memory
        )
        # numpy view sharing storage with min_p_cpu_tensor, used for the
        # scalar reads/writes in update_state().
        self.min_p_cpu = self.min_p_cpu_tensor.numpy()

        # A separate device copy is needed only when the target device is not
        # the CPU; otherwise the CPU tensor is used directly.
        self.use_double_tensor = torch.device(device).type != "cpu"

        if self.use_double_tensor:
            # Pre-allocated device tensor
            self.min_p_device: torch.Tensor = torch.empty(
                (max_num_reqs,), dtype=torch.float32, device=device
            )
        else:
            self.min_p_device = self.min_p_cpu_tensor
        # Current slice of the device tensor
        self.min_p: torch.Tensor = self.min_p_device[:0]

    def is_argmax_invariant(self) -> bool:
        """Min-p never impacts greedy sampling"""
        return True

    def get_min_p_by_index(self, index: int) -> float:
        """Return the host-side min_p value stored for batch slot `index`."""
        return float(self.min_p_cpu[index])

    def update_state(self, batch_update: BatchUpdate | None):
        """Apply batch changes to the per-slot min_p values and min_p_count,
        then refresh `self.min_p` when it is in use.

        Handles added, then removed, then moved requests.
        NOTE(review): BatchUpdate documents the order removed, added, moved.
        Handling added first is only equivalent if an index never appears in
        both `added` and `removed` of one update -- confirm.
        """
        if not batch_update:
            return

        needs_update = False
        # Process added requests.
        for index, params, _, _ in batch_update.added:
            min_p = params.min_p
            min_p_before = self.min_p_cpu[index]
            if min_p_before != min_p:
                needs_update = True
                self.min_p_cpu[index] = min_p
                # Keep min_p_count equal to the number of non-zero slots.
                if min_p and not min_p_before:
                    self.min_p_count += 1
                elif not min_p and min_p_before:
                    self.min_p_count -= 1

        # With no non-zero slots left, every slot is already 0, so removals
        # and moves cannot change any value.
        if self.min_p_count:
            # Process removed requests.
            if batch_update.removed:
                needs_update = True
                for index in batch_update.removed:
                    if self.min_p_cpu[index]:
                        self.min_p_cpu[index] = 0
                        self.min_p_count -= 1

            # Process moved requests, unidirectional (a->b) and swap (a<->b).
            for adx, bdx, direct in batch_update.moved:
                min_p_a, min_p_b = self.min_p_cpu[adx], self.min_p_cpu[bdx]
                if min_p_a != min_p_b:
                    needs_update = True
                    self.min_p_cpu[bdx] = min_p_a
                    if direct == MoveDirectionality.SWAP:
                        self.min_p_cpu[adx] = min_p_b
                if direct == MoveDirectionality.UNIDIRECTIONAL:
                    # Slot a is vacated, and whatever occupied slot b is
                    # overwritten, so b's old non-zero min_p stops counting.
                    if min_p_a:
                        self.min_p_cpu[adx] = 0
                    if min_p_b:
                        self.min_p_count -= 1

        # Update tensors if needed.
        size = batch_update.batch_size
        if self.min_p_count and (needs_update or self.min_p.shape[0] != size):
            self.min_p = self.min_p_device[:size]
            if self.use_double_tensor:
                self.min_p.copy_(self.min_p_cpu_tensor[:size], non_blocking=True)
            # Reshape to (size, 1) so it broadcasts against the per-row max
            # probabilities in apply().
            self.min_p.unsqueeze_(1)

    def apply(self, logits: torch.Tensor) -> torch.Tensor:
        """Mask tokens below each row's min_p threshold with -inf, in place.

        Returns `logits` unchanged when no request uses min_p. Presumably
        `logits` is (batch_size, vocab_size), matching the (batch_size, 1)
        shape of `self.min_p` -- confirm with the caller.
        """
        if not self.min_p_count:
            return logits

        # Convert logits to probability distribution
        probability_values = torch.nn.functional.softmax(logits, dim=-1)
        # Calculate maximum probabilities per sequence
        max_probabilities = torch.amax(probability_values, dim=-1, keepdim=True)
        # Adjust min_p
        adjusted_min_p = max_probabilities.mul_(self.min_p)
        # Identify valid tokens using threshold comparison
        invalid_token_mask = probability_values < adjusted_min_p
        # Apply mask using boolean indexing
        logits.masked_fill_(invalid_token_mask, -float("inf"))
        return logits

is_argmax_invariant()

Min-p never impacts greedy sampling

Source code in vllm/v1/sample/logits_processor/builtin.py
def is_argmax_invariant(self) -> bool:
    """Return True: min-p filtering never changes the greedy (argmax)
    choice."""
    argmax_unaffected = True
    return argmax_unaffected

MinTokensLogitsProcessor

Bases: LogitsProcessor

Methods:

Source code in vllm/v1/sample/logits_processor/builtin.py
class MinTokensLogitsProcessor(LogitsProcessor):
    """Prevent stop tokens from being sampled before ``min_tokens`` is reached.

    For each request whose ``SamplingParams.min_tokens`` constraint is not yet
    met, the logits of every id in ``params.all_stop_token_ids`` are set to
    ``-inf`` so the request cannot terminate early.
    """

    def __init__(
        self, vllm_config: "VllmConfig", device: torch.device, is_pin_memory: bool
    ):
        """Initialize empty per-request state.

        ``vllm_config`` is accepted to satisfy the logits processor interface
        and is not used here.
        """
        self.device = device
        self.pin_memory = is_pin_memory
        # index -> (min_toks, output_token_ids, stop_token_ids)
        self.min_toks: dict[int, tuple[int, Sequence[int], set[int]]] = {}

        # (req_idx_tensor, eos_tok_id_tensor): paired row / column indices of
        # the logits entries that apply() masks; rebuilt by update_state().
        self.logits_slice: tuple[torch.Tensor, torch.Tensor] = (
            self._device_tensor([], torch.int32),
            self._device_tensor([], torch.int32),
        )

        # Fill value for masked entries. Stored as float32; it is cast to the
        # logits dtype at use because index_put_ requires matching dtypes.
        self.neg_inf_tensor = torch.tensor(
            -float("inf"), dtype=torch.float32, device=self.device
        )

    def is_argmax_invariant(self) -> bool:
        """By censoring stop tokens, min-tokens can change the outcome
        of the argmax operation in greedy sampling."""
        return False

    @staticmethod
    def add_request(
        params: SamplingParams, _: list[int] | None, output_tok_ids: list[int]
    ) -> tuple[int, Sequence[int], set[int]] | None:
        """Build the tracked state for a newly added request.

        Args:
          params: the request's sampling parameters.
          _: unused positional argument supplied by the caller.
          output_tok_ids: the request's output token ids. The list is returned
            as-is (not copied); update_state() later re-reads its length,
            presumably because the owner keeps appending to it — TODO confirm.

        Returns:
          ``(min_tokens, output_tok_ids, all_stop_token_ids)``, or None when
          ``min_tokens`` is unset/0 or already satisfied.
        """
        min_tokens = params.min_tokens
        if not min_tokens or len(output_tok_ids) >= min_tokens:
            return None
        return min_tokens, output_tok_ids, params.all_stop_token_ids

    def update_state(self, batch_update: BatchUpdate | None):
        """Sync tracked requests with ``batch_update`` and refresh indices.

        Batch changes are folded into ``self.min_toks`` via
        ``process_dict_updates`` (whose return value is treated as a "changed"
        flag). Requests whose output has reached ``min_tokens`` are dropped, and
        ``logits_slice`` is rebuilt whenever the tracked set changed.
        """
        needs_update = process_dict_updates(
            self.min_toks, batch_update, self.add_request
        )
        if self.min_toks:
            # Check for any requests that have attained their min tokens.
            to_remove = tuple(
                index
                for index, (min_toks, out_tok_ids, _) in self.min_toks.items()
                if len(out_tok_ids) >= min_toks
            )
            if to_remove:
                needs_update = True
                for index in to_remove:
                    del self.min_toks[index]

        # Update tensors if needed.
        if needs_update:
            # One (row, stop-token) pair per stop token of each tracked request.
            reqs: list[int] = []
            tok_ids: list[int] = []
            for req, (_, _, stop_tok_ids) in self.min_toks.items():
                reqs.extend([req] * len(stop_tok_ids))
                tok_ids.extend(stop_tok_ids)

            self.logits_slice = (
                self._device_tensor(reqs, torch.int32),
                self._device_tensor(tok_ids, torch.int32),
            )

    def _device_tensor(self, data: list, dtype: torch.dtype) -> torch.Tensor:
        """Build a CPU tensor (pinned when configured) and copy it to
        ``self.device`` with a non-blocking transfer."""
        return torch.tensor(
            data, device="cpu", dtype=dtype, pin_memory=self.pin_memory
        ).to(device=self.device, non_blocking=True)

    def apply(self, logits: torch.Tensor) -> torch.Tensor:
        """Set stop-token logits to -inf, in place, for requests below min
        length, and return ``logits``."""
        if self.min_toks:
            # Inhibit EOS token for requests which have not reached min length.
            # Cast the fill value: index_put_ rejects mismatched dtypes, so a
            # float32 value would fail on fp16/bf16 logits.
            logits.index_put_(
                self.logits_slice, self.neg_inf_tensor.to(dtype=logits.dtype)
            )
        return logits

    def apply_with_spec_decode(
        self,
        logits: torch.Tensor,
        num_draft_tokens: list[int],
    ) -> torch.Tensor:
        """Spec-decode version of apply().
        Priority: ``min_tokens`` > ``stop_token_ids`` / EOS.
        Example: ``num_draft_tokens = [2, 3, 1]``
          → ``logits`` shape ``[6, V]``, ``cumsum = [0, 2, 5, 6]``
          → request 0 owns rows 0‑1, request 1 rows 2‑4, request 2 row 5.

        Only the leading ``min_tokens - len(output)`` rows of each request
        (clamped to its row count) have their stop tokens masked; ``logits`` is
        modified in place and returned.
        """
        if not self.min_toks:
            return logits

        num_draft_arr = np.array(num_draft_tokens, dtype=np.int64)
        # cumsum[i] is the first logits row owned by request i.
        cumsum = np.concatenate([[0], np.cumsum(num_draft_arr)])

        entries = [
            (req_idx, min_tok, len(out_tok_ids), list(stop_tok_ids))
            for req_idx, (min_tok, out_tok_ids, stop_tok_ids) in self.min_toks.items()
            if stop_tok_ids
        ]

        if not entries:
            return logits

        all_rows: list[np.ndarray] = []  # row indices to mask
        all_toks: list[np.ndarray] = []  # stop-token ids at those rows

        for req_idx, min_tok, current_len, stop_toks in entries:
            remaining = min_tok - current_len
            # How many leading draft positions still need stop-token masking.
            n_mask = int(min(max(remaining, 0), num_draft_arr[req_idx]))

            if n_mask > 0:
                offset = cumsum[req_idx]
                row_indices = np.arange(offset, offset + n_mask, dtype=np.int64)
                n_stop = len(stop_toks)
                # Cartesian product: every masked row x every stop token.
                all_rows.append(np.repeat(row_indices, n_stop))
                all_toks.append(np.tile(stop_toks, n_mask))

        if all_rows:
            rows_arr = np.concatenate(all_rows)
            toks_arr = np.concatenate(all_toks)
            # (row_indices, token_indices) for index_put_ to set -inf.
            logits_slice = (
                torch.from_numpy(rows_arr).to(self.device, non_blocking=True),
                torch.from_numpy(toks_arr).to(self.device, non_blocking=True),
            )
            # Match the logits dtype; index_put_ rejects mismatched dtypes.
            logits.index_put_(
                logits_slice, self.neg_inf_tensor.to(dtype=logits.dtype)
            )

        return logits

apply_with_spec_decode(logits, num_draft_tokens)

Spec-decode version of apply(). Priority: min_tokens > stop_token_ids / EOS. Example: num_draft_tokens = [2, 3, 1] → logits shape [6, V], cumsum = [0, 2, 5, 6] → request 0 owns rows 0‑1, request 1 rows 2‑4, request 2 row 5.

Source code in vllm/v1/sample/logits_processor/builtin.py
def apply_with_spec_decode(
    self,
    logits: torch.Tensor,
    num_draft_tokens: list[int],
) -> torch.Tensor:
    """Spec-decode version of apply().
    Priority: ``min_tokens`` > ``stop_token_ids`` / EOS.
    Example: ``num_draft_tokens = [2, 3, 1]``
      → ``logits`` shape ``[6, V]``, ``cumsum = [0, 2, 5, 6]``
      → request 0 owns rows 0‑1, request 1 rows 2‑4, request 2 row 5.

    Only the leading ``min_tokens - len(output)`` rows of each request
    (clamped to its row count) have their stop tokens masked; ``logits`` is
    modified in place and returned.
    """
    if not self.min_toks:
        return logits

    num_draft_arr = np.array(num_draft_tokens, dtype=np.int64)
    # cumsum[i] is the first logits row owned by request i.
    cumsum = np.concatenate([[0], np.cumsum(num_draft_arr)])

    entries = [
        (req_idx, min_tok, len(out_tok_ids), list(stop_tok_ids))
        for req_idx, (min_tok, out_tok_ids, stop_tok_ids) in self.min_toks.items()
        if stop_tok_ids
    ]

    if not entries:
        return logits

    all_rows: list[np.ndarray] = []  # row indices to mask
    all_toks: list[np.ndarray] = []  # stop-token ids at those rows

    for req_idx, min_tok, current_len, stop_toks in entries:
        remaining = min_tok - current_len
        # How many leading draft positions still need stop-token masking.
        n_mask = int(min(max(remaining, 0), num_draft_arr[req_idx]))

        if n_mask > 0:
            offset = cumsum[req_idx]
            row_indices = np.arange(offset, offset + n_mask, dtype=np.int64)
            n_stop = len(stop_toks)
            # Cartesian product: every masked row x every stop token.
            all_rows.append(np.repeat(row_indices, n_stop))
            all_toks.append(np.tile(stop_toks, n_mask))

    if all_rows:
        rows_arr = np.concatenate(all_rows)
        toks_arr = np.concatenate(all_toks)
        # (row_indices, token_indices) for index_put_ to set -inf.
        logits_slice = (
            torch.from_numpy(rows_arr).to(self.device, non_blocking=True),
            torch.from_numpy(toks_arr).to(self.device, non_blocking=True),
        )
        # Match the logits dtype; index_put_ rejects mismatched dtypes, so a
        # float32 fill value would fail on fp16/bf16 logits.
        logits.index_put_(logits_slice, self.neg_inf_tensor.to(dtype=logits.dtype))

    return logits

is_argmax_invariant()

By censoring stop tokens, min-tokens can change the outcome of the argmax operation in greedy sampling.

Source code in vllm/v1/sample/logits_processor/builtin.py
def is_argmax_invariant(self) -> bool:
    """Report that min-tokens can alter the greedy (argmax) choice.

    Stop tokens are forced to ``-inf`` until ``min_tokens`` is reached, so the
    token that would otherwise win the argmax may be censored.
    """
    argmax_preserved = False
    return argmax_preserved

_load_custom_logitsprocs(logits_processors)

Load all custom logits processors.

  • First load all installed logitproc plugins
  • Second load custom logitsprocs passed by the user at initialization time

Parameters:

  • logits_processors

    (Sequence[str | type[LogitsProcessor]] | None) –

    potentially mixed list of logitproc types and logitproc type fully-qualified names (FQCNs) which need to be loaded

Returns:

  • list[type[LogitsProcessor]] – A list of all loaded logitproc types

Source code in vllm/v1/sample/logits_processor/__init__.py
def _load_custom_logitsprocs(
    logits_processors: Sequence[str | type[LogitsProcessor]] | None,
) -> list[type[LogitsProcessor]]:
    """Load all custom logits processors.

    * First load all installed logitproc plugins
    * Second load custom logitsprocs passed by the user at initialization time

    On TPU no custom logitsprocs are loaded at all (neither plugins nor
    user-specified ones); a warning is logged if the caller supplied any.

    Args:
      logits_processors: potentially mixed list of logitproc types and
                         logitproc type fully-qualified names (FQCNs)
                         which need to be loaded

    Returns:
      A list of all loaded logitproc types (plugins first, then the
      user-specified ones)
    """
    from vllm.platforms import current_platform

    if current_platform.is_tpu():
        # TODO(andy) - vLLM V1 on TPU does not support custom logitsprocs
        if logits_processors:
            # Previously these were dropped silently; surface it to the user.
            logger.warning(
                "Custom logits processors are not supported on TPU; "
                "ignoring %d user-specified logits processor(s).",
                len(logits_processors),
            )
        return []

    return _load_logitsprocs_plugins() + _load_logitsprocs_by_fqcns(logits_processors)

_load_logitsprocs_by_fqcns(logits_processors)

Load logit processor types, identifying them by fully-qualified class names (FQCNs).

Effectively, a mixed list of logitproc types and FQCN strings is converted into a list of entirely logitproc types, by loading from the FQCNs.

FQCN syntax is <module>:<type>, i.e. x.y.z:CustomLogitProc

Already-loaded logitproc types must be subclasses of LogitsProcessor

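Parameters:

  • logits_processors (Sequence[str | type[LogitsProcessor]] | None) – potentially mixed list of logitsprocs types and FQCN strings for logitproc types

Returns:

  • list[type[LogitsProcessor]] – List of logitproc types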

Source code in vllm/v1/sample/logits_processor/__init__.py
def _load_logitsprocs_by_fqcns(
    logits_processors: Sequence[str | type[LogitsProcessor]] | None,
) -> list[type[LogitsProcessor]]:
    """Load logit processor types, identifying them by fully-qualified class
    names (FQCNs).

    Effectively, a mixed list of logitproc types and FQCN strings is converted
    into a list of entirely logitproc types, by loading from the FQCNs.

    FQCN syntax is <module>:<type> i.e. x.y.z:CustomLogitProc

    Already-loaded logitproc types must be subclasses of LogitsProcessor

    Args:
      logits_processors: Potentially mixed list of logitsprocs types and FQCN
                         strings for logitproc types

    Returns:
      List of logitproc types, in input order

    Raises:
      ValueError: if an FQCN is malformed, or an entry does not resolve to a
        LogitsProcessor subclass.
      RuntimeError: if the module named by an FQCN fails to import.
      AttributeError: if the dotted type name cannot be resolved in its module.
    """
    if not logits_processors:
        return []

    logger.debug(
        "%s additional custom logits processors specified, checking whether "
        "they need to be loaded.",
        len(logits_processors),
    )

    classes: list[type[LogitsProcessor]] = []
    for ldx, logitproc in enumerate(logits_processors):
        if isinstance(logitproc, type):
            logger.debug(" - Already-loaded logit processor: %s", logitproc.__name__)
            if not issubclass(logitproc, LogitsProcessor):
                raise ValueError(
                    f"{logitproc.__name__} is not a subclass of LogitsProcessor"
                )
            classes.append(logitproc)
            continue

        logger.debug("- Loading logits processor %s", logitproc)
        # Validate the "<module>:<type>" shape up front so a malformed FQCN
        # produces an actionable error instead of a tuple-unpacking ValueError.
        module_path, sep, qualname = logitproc.partition(":")
        if not (sep and module_path and qualname) or ":" in qualname:
            raise ValueError(
                f"Invalid logits processor FQCN {logitproc!r}: expected "
                "'<module>:<type>', e.g. 'x.y.z:CustomLogitProc'"
            )

        try:
            # Load module
            with guard_cuda_initialization():
                module = importlib.import_module(module_path)
        except Exception as e:
            logger.error(
                "Failed to load %sth LogitsProcessor plugin %s: %s",
                ldx,
                logitproc,
                e,
            )
            raise RuntimeError(
                f"Failed to load {ldx}th LogitsProcessor plugin {logitproc}"
            ) from e

        # Walk down dotted name to get logitproc class
        obj = module
        for attr in qualname.split("."):
            try:
                obj = getattr(obj, attr)
            except AttributeError as e:
                # Keep the exception type, but say which FQCN/attribute failed.
                raise AttributeError(
                    f"Could not resolve {qualname!r} in module {module_path!r} "
                    f"(missing attribute {attr!r}) for logits processor "
                    f"{logitproc!r}"
                ) from e
        if not isinstance(obj, type):
            raise ValueError("Loaded logit processor must be a type.")
        if not issubclass(obj, LogitsProcessor):
            raise ValueError(f"{obj.__name__} must be a subclass of LogitsProcessor")
        classes.append(obj)

    return classes

_load_logitsprocs_plugins()

Load all installed logit processor plugins

Source code in vllm/v1/sample/logits_processor/__init__.py
def _load_logitsprocs_plugins() -> list[type[LogitsProcessor]]:
    """Load every logits processor registered as an installed plugin.

    Plugins are discovered through the ``LOGITSPROCS_GROUP`` entry-point group
    and loaded inside ``guard_cuda_initialization()``.

    Returns:
      The loaded objects in entry-point iteration order, or an empty list when
      no plugin is installed for the group.

    Raises:
      RuntimeError: if any entry point fails to load (original error chained).
    """

    from importlib.metadata import entry_points

    plugins = entry_points(group=LOGITSPROCS_GROUP)
    plugin_count = len(plugins)
    if plugin_count == 0:
        logger.debug("No logitsprocs plugins installed (group %s).", LOGITSPROCS_GROUP)
        return []

    logger.debug("Loading installed logitsprocs plugins (group %s):", LOGITSPROCS_GROUP)
    loaded: list[type[LogitsProcessor]] = []
    for ep in plugins:
        try:
            logger.debug(
                "- Loading logitproc plugin entrypoint=%s target=%s",
                ep.name,
                ep.value,
            )
            with guard_cuda_initialization():
                plugin_cls = ep.load()
            loaded.append(plugin_cls)
        except Exception as err:
            logger.error("Failed to load LogitsProcessor plugin %s: %s", ep, err)
            raise RuntimeError(f"Failed to load LogitsProcessor plugin {ep}") from err
    return loaded