"""Source code for pubsub: process-local, in-memory publish/subscribe helpers."""

import asyncio
import inspect
import json
import os
from collections.abc import AsyncIterator, Callable
from typing import Any

# Soft cap per-channel queue depth to avoid unbounded memory growth when a
# producer runs without consumers. Overflows drop the oldest item.
_QUEUE_MAXSIZE = max(1, int(os.getenv("PUBSUB_QUEUE_MAXSIZE", "32")))

# ───────────────────────── Pubsub backend ──────────────────────────
# Use a broad type to avoid runtime attribute access on stubbed modules
bus: Any | None = None  # set by init_pubsub()


# ───────────────────────── In-memory fallback ─────────────────────────
class _MemPubSub:
    """Subscription handle returned by ``_InMemoryBus.pubsub()``.

    Exposes the ``subscribe()`` / ``listen()`` pair used by the module-level
    ``subscribe()`` helper. Each handle owns a private queue on the bus, so
    every subscriber of a channel receives every message published after it
    subscribed (fan-out), instead of competing with other subscribers.
    """

    def __init__(self, parent: "_InMemoryBus"):
        self._parent = parent
        self._channel: str | None = None
        # Private delivery queue registered on the bus; None until subscribe().
        self._queue: asyncio.Queue | None = None

    async def subscribe(self, channel: str):
        """Register this handle on *channel* and return ``True``.

        Registration happens here rather than in ``listen()`` so messages
        published between the two calls are not missed. Subscribing again
        detaches the handle from its previous channel first.
        """
        if self._queue is not None:
            self._parent._detach(self._channel or "", self._queue)
        self._channel = channel
        self._queue = self._parent._attach(channel)
        return True

    async def listen(self):
        """Yield ``{"type": "message", "data": <payload>}`` dicts indefinitely.

        Listening without a prior ``subscribe()`` uses the channel ``""``.
        The handle's queue is detached from the bus when the generator closes.
        """
        if self._queue is None:
            await self.subscribe(self._channel or "")
        ch = self._channel or ""
        q = self._queue
        try:
            while True:
                data = await q.get()
                yield {"type": "message", "data": data}
        finally:
            self._parent._detach(ch, q)
            if self._queue is q:
                self._queue = None


class _InMemoryBus:
    """Process-local pubsub backend with fan-out delivery.

    Every live subscriber of a channel gets its own bounded queue and receives
    each published message. Messages published while a channel has no
    subscribers are held in a bounded per-channel backlog and handed to the
    next subscriber that attaches. Unread messages in a subscriber's queue are
    discarded when that subscriber detaches. Every queue drops its oldest item
    on overflow.
    """

    def __init__(self, maxsize: int = _QUEUE_MAXSIZE):
        self._maxsize = max(1, int(maxsize))
        # channel -> messages published while the channel had no subscribers
        self._channels: dict[str, asyncio.Queue] = {}
        # channel -> one private queue per live subscriber (never left empty)
        self._subscribers: dict[str, set[asyncio.Queue]] = {}

    def _new_queue(self) -> asyncio.Queue:
        """Create a queue with this bus's capacity."""
        return asyncio.Queue(maxsize=self._maxsize)

    def _ensure(self, ch: str) -> asyncio.Queue:
        """Return the backlog queue for *ch*, creating it on first use."""
        if ch not in self._channels:
            self._channels[ch] = self._new_queue()
        return self._channels[ch]

    def _attach(self, ch: str) -> asyncio.Queue:
        """Register and return a new subscriber queue for *ch*.

        Any backlog accumulated while *ch* had no subscribers is moved into
        the new queue; both have the same capacity, so the move drops nothing.
        """
        q = self._new_queue()
        backlog = self._channels.pop(ch, None)
        while backlog is not None and not backlog.empty():
            q.put_nowait(backlog.get_nowait())
        self._subscribers.setdefault(ch, set()).add(q)
        return q

    def _detach(self, ch: str, q: asyncio.Queue) -> None:
        """Unregister subscriber queue *q* from *ch* (no-op if not present)."""
        subs = self._subscribers.get(ch)
        if subs is None:
            return
        subs.discard(q)
        if not subs:
            del self._subscribers[ch]

    @staticmethod
    def _offer(q: asyncio.Queue, message: str) -> None:
        """Put *message* on *q*, evicting the oldest item if *q* is full."""
        try:
            q.put_nowait(message)
            return
        except asyncio.QueueFull:
            pass
        try:
            q.get_nowait()
        except asyncio.QueueEmpty:
            pass
        try:
            q.put_nowait(message)
        except asyncio.QueueFull:
            # Defensive only: nothing is awaited between the get above and
            # this put, so a slot should be free.
            pass

    async def ping(self):  # pragma: no cover - trivial
        return True

    async def publish(self, channel: str, message: str):
        """Deliver *message* to every subscriber of *channel*.

        With no subscribers, the message is kept in the channel backlog.
        """
        # _offer never awaits, so the subscriber set cannot change mid-loop.
        targets = self._subscribers.get(channel) or (self._ensure(channel),)
        for q in targets:
            self._offer(q, message)

    def pubsub(self):
        """Return a new, unsubscribed subscription handle on this bus."""
        return _MemPubSub(self)


async def init_pubsub() -> None:
    """Install a fresh in-memory pubsub backend as the module-level ``bus``.

    The backend is process-local: nothing published through it leaves this
    process. Calling this again replaces the current backend with a new one.
    """
    global bus
    bus = _InMemoryBus()
    print("[pubsub] Using in-memory pubsub backend")


# ────────────────────────── Publishing helpers ─────────────────────────
async def publish(channel: str, message: Any) -> None:
    """JSON-encode *message* and push it to *channel*.

    Raises:
        RuntimeError: if ``init_pubsub()`` has not been awaited yet.
        TypeError: if *message* is not JSON-serialisable (from ``json.dumps``).
    """
    if bus is None:
        # Fail with guidance instead of "'NoneType' has no attribute 'publish'".
        raise RuntimeError(
            "pubsub backend is not initialised; await init_pubsub() first"
        )
    await bus.publish(channel, json.dumps(message))


async def run_producer(
    channel: str,
    producer: Callable[..., Any],
    poll_interval: float = 1.0,
) -> None:
    """
    Continuously feed a *producer* into the in-memory pubsub backend.

    • **async-generator functions** - every yielded item is published.
    • **coroutine functions** - called every *poll_interval* s, their returned
      value is published.
    • **plain callables** - also polled every *poll_interval* s; a
      non-awaitable return value is published as-is.

    In polling mode any ``Exception`` raised by the producer or by ``publish``
    is printed and polling continues; ``asyncio.CancelledError`` (a
    ``BaseException``) still stops the loop. In async-generator mode,
    exceptions propagate and end the producer.
    """
    if inspect.isasyncgenfunction(producer):
        async for item in producer():
            await publish(channel, item)
        return

    # functools.partial objects and callable instances have no __name__;
    # resolve a label once so the error handler itself cannot raise.
    label = getattr(producer, "__name__", None) or repr(producer)
    while True:
        try:
            item = producer()
            if inspect.isawaitable(item):
                item = await item
            await publish(channel, item)
        except Exception as exc:
            print(f"[pubsub] {label} error: {exc!r}")
        await asyncio.sleep(poll_interval)


# ─────────────────────────── Subscriber API ────────────────────────────
async def subscribe(channel_name: str) -> AsyncIterator[Any]:
    """Yield decoded JSON messages arriving on *channel_name*.

    Only backend entries whose ``type`` is ``"message"`` are considered;
    payloads that are not valid JSON are skipped silently.

    Raises:
        RuntimeError: if ``init_pubsub()`` has not been awaited yet (raised on
            first iteration, because this is an async generator).
    """
    if bus is None:
        raise RuntimeError(
            "pubsub backend is not initialised; await init_pubsub() first"
        )
    pubsub = bus.pubsub()
    await pubsub.subscribe(channel_name)
    stream = pubsub.listen()
    try:
        async for message in stream:
            if message["type"] != "message":
                continue
            # Decode outside the yield so exceptions thrown into this
            # generator by the consumer are never swallowed here.
            try:
                payload = json.loads(message["data"])
            except json.JSONDecodeError:
                continue
            yield payload
    finally:
        # Release the backend subscription as soon as our consumer stops,
        # rather than waiting for garbage-collection-time finalisation.
        await stream.aclose()