vllm.entrypoints.speech_to_text.realtime.metrics

ASGI middleware for WebSocket Prometheus metrics.

Modeled after prometheus-fastapi-instrumentator, this middleware transparently instruments WebSocket endpoints with standard metrics without requiring changes to handler code.

NOTE: This module intentionally has zero vllm imports so that it can be extracted into a standalone package (similar to prometheus-fastapi-instrumentator) in the future. Please keep it that way.

Classes:

WebSocketMetricsMiddleware

Pure ASGI middleware that instruments WebSocket connections.

Tracks active connections (gauge), total connections (counter), and connection duration (histogram) for all WebSocket endpoints.

Usage::

app.add_middleware(WebSocketMetricsMiddleware)

Source code in vllm/entrypoints/speech_to_text/realtime/metrics.py
class WebSocketMetricsMiddleware:
    """Pure ASGI middleware that instruments WebSocket connections.

    Tracks active connections (gauge), total connections (counter),
    and connection duration (histogram) for all WebSocket endpoints.

    Usage::

        app.add_middleware(WebSocketMetricsMiddleware)
    """

    def __init__(self, app: ASGIApp) -> None:
        self.app = app

    def __call__(self, scope: Scope, receive: Receive, send: Send) -> Awaitable[None]:
        # Non-WebSocket scopes (HTTP, lifespan) pass through untouched.
        if scope["type"] != "websocket":
            return self.app(scope, receive, send)

        return self._handle_websocket(scope, receive, send)

    async def _handle_websocket(
        self, scope: Scope, receive: Receive, send: Send
    ) -> None:
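        # Stays None until the handshake is accepted, so rejected
        # connections are never counted.
        start_time: float | None = None

        async def send_wrapper(message: Message) -> None:
            nonlocal start_time
            # A session starts when the application accepts the handshake.
            if message["type"] == "websocket.accept":
                start_time = time.monotonic()
                _active_sessions.inc()
                _total_sessions.inc()
            await send(message)

        try:
            await self.app(scope, receive, send_wrapper)
        finally:
            # Runs on normal close and on errors alike; only accepted
            # sessions are decremented and timed.
            if start_time is not None:
                _active_sessions.dec()
                _session_duration.observe(time.monotonic() - start_time)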