import asyncio
import inspect
import json
import os
from collections.abc import AsyncIterator, Callable
from typing import Any
# Soft cap per-channel queue depth to avoid unbounded memory growth when a
# producer runs without consumers. Overflows drop the oldest item.
_QUEUE_MAXSIZE = max(1, int(os.getenv("PUBSUB_QUEUE_MAXSIZE", "32")))
# ───────────────────────── Pubsub backend ──────────────────────────
# Use a broad type to avoid runtime attribute access on stubbed modules
bus: Any | None = None # set by init_pubsub()
# ───────────────────────── In-memory fallback ─────────────────────────
class _MemPubSub:
def __init__(self, parent: "_InMemoryBus"):
self._parent = parent
self._channel: str | None = None
async def subscribe(self, channel: str):
self._channel = channel
self._parent._ensure(channel)
return True
async def listen(self):
ch = self._channel or ""
q: asyncio.Queue = self._parent._ensure(ch)
while True:
data = await q.get()
yield {"type": "message", "data": data}
class _InMemoryBus:
def __init__(self, maxsize: int = _QUEUE_MAXSIZE):
self._channels: dict[str, asyncio.Queue] = {}
self._maxsize = max(1, int(maxsize))
def _ensure(self, ch: str):
if ch not in self._channels:
self._channels[ch] = asyncio.Queue(maxsize=self._maxsize)
return self._channels[ch]
async def ping(self): # pragma: no cover - trivial
return True
async def publish(self, channel: str, message: str):
q = self._ensure(channel)
try:
q.put_nowait(message)
return
except asyncio.QueueFull:
# Drop the oldest item to make space for the newest snapshot.
try:
q.get_nowait()
except asyncio.QueueEmpty:
pass
try:
q.put_nowait(message)
except asyncio.QueueFull:
# If another coroutine already refilled the slot, drop quietly.
pass
def pubsub(self):
return _MemPubSub(self)
[docs]
async def init_pubsub() -> None:
    """
    Install a new in-memory backend as the module-level ``bus``.

    The backend lives only in this process. Each call constructs a fresh
    ``_InMemoryBus`` and rebinds ``bus`` to it.
    """
    global bus
    backend = _InMemoryBus()
    bus = backend
    print("[pubsub] Using in-memory pubsub backend")
# ────────────────────────── Publishing helpers ─────────────────────────
[docs]
async def publish(channel: str, message: Any) -> None:
    """
    Serialize *message* with ``json.dumps`` and send the text on *channel*.

    ``bus`` must already be set (normally by ``init_pubsub``); while it is
    still ``None`` the attribute lookup fails before any serialization.
    """
    send = bus.publish
    payload = json.dumps(message)
    await send(channel, payload)
[docs]
async def run_producer(
    channel: str,
    producer: Callable[..., Any],
    poll_interval: float = 1.0,
) -> None:
    """
    Continuously feed a *producer* into the in-memory pubsub backend.

    • **async-generator functions** - every yielded item is published. Returns
      when the generator is exhausted; an exception raised by the generator
      itself propagates to the caller.
    • **coroutine functions** (or plain callables) - called every
      *poll_interval* s, and their returned value is published. Errors are
      logged and polling continues.

    In both modes, an item that fails to publish (e.g. it is not
    JSON-serializable) is logged and skipped instead of stopping the producer.
    """
    # functools.partial objects and callable instances have no __name__;
    # looking it up inside the except-handler would crash the loop.
    name = getattr(producer, "__name__", repr(producer))

    if inspect.isasyncgenfunction(producer):
        async for item in producer():
            try:
                await publish(channel, item)
            except Exception as exc:
                print(f"[pubsub] {name} error: {exc!r}")
        return

    while True:
        try:
            item = producer()
            # Accept plain callables as well as coroutine functions.
            if inspect.isawaitable(item):
                item = await item
            await publish(channel, item)
        except Exception as exc:
            print(f"[pubsub] {name} error: {exc!r}")
        await asyncio.sleep(poll_interval)
# ─────────────────────────── Subscriber API ────────────────────────────
[docs]
async def subscribe(channel_name: str) -> AsyncIterator[Any]:
    """
    Subscribe to *channel_name* on the active ``bus`` and yield each payload
    decoded with ``json.loads``.

    Events whose ``"type"`` is not ``"message"`` are ignored, as are payloads
    that are not valid JSON.
    """
    handle = bus.pubsub()
    await handle.subscribe(channel_name)
    async for event in handle.listen():
        if event["type"] != "message":
            continue
        try:
            yield json.loads(event["data"])
        except json.JSONDecodeError:
            continue