Source code for config_editor

# config_editor.py
from __future__ import annotations

import asyncio
import difflib
import hashlib
import hmac
import inspect
import json
import logging
import os
import signal
import time
from typing import Any

from fastapi import (
    APIRouter,
    Body,
    Depends,
    HTTPException,
    Query,
    Request,
    status,
)
from fastapi.responses import JSONResponse

import endpoint_config as ec
from utils.validation import ValidationError, validate_payload

# Module logger. A fixed name (rather than __name__) keeps log routing stable
# no matter how this module is imported.
log = logging.getLogger(name="config_editor")
# Router for the Config UI endpoints registered below.
router = APIRouter()

# ───────────── session / csrf verification (shared with app_factory) ─────────
# HMAC key for Config UI session cookies: CFG_UI_SECRET, else JWT_SECRET,
# else the literal placeholder "replace-me".
_CFG_UI_SECRET = os.getenv("CFG_UI_SECRET") or os.getenv(
    "JWT_SECRET", "replace-me"
)
# Session lifetime in minutes; a non-integer value fails at import time.
_CFG_UI_TTL_MIN = int(os.getenv("CFG_UI_TTL_MIN", "30"))

if _CFG_UI_SECRET in ("", "replace-me"):
    # With a publicly known (or empty) key anyone can mint a validly signed
    # cfgui_session cookie, so make the misconfiguration loud.
    logging.getLogger("config_editor").warning(
        "Config UI session secret is unset or the 'replace-me' placeholder; "
        "set CFG_UI_SECRET (or JWT_SECRET) to a strong random value"
    )


def _sign(payload: str) -> str:
    """Return the hex HMAC-SHA256 of *payload*, keyed with ``_CFG_UI_SECRET``."""
    key = _CFG_UI_SECRET.encode()
    mac = hmac.new(key, msg=payload.encode(), digestmod=hashlib.sha256)
    return mac.hexdigest()


def _valid_session_cookie(val: str | None) -> bool:
    """
    Check a ``cfgui_session`` cookie of the form ``<ts>.<nonce>.<sig>``.

    ``sig`` must be ``_sign("<ts>.<nonce>")``, and ``ts`` (integer Unix
    seconds) must be no older than ``_CFG_UI_TTL_MIN`` minutes.

    Returns False for any malformed, tampered or expired value. It never
    raises on bad input, so callers can map False straight to a 401.
    """
    if not val:
        return False
    parts = val.split(".", 2)
    # A value with only one "." previously crashed the 3-way unpack
    # (ValueError -> HTTP 500); treat it as an invalid cookie instead.
    if len(parts) != 3:
        return False
    ts_str, nonce, sig = parts
    expected = _sign(f"{ts_str}.{nonce}")
    # Constant-time comparison so the signature cannot be recovered via
    # timing; comparing bytes avoids TypeError on non-ASCII cookie content.
    if not hmac.compare_digest(expected.encode(), sig.encode()):
        return False
    try:
        ts = int(ts_str)
    except ValueError:
        return False
    return (time.time() - ts) <= (_CFG_UI_TTL_MIN * 60)


def _require_cfg_session(request: Request):
    """
    FastAPI dependency: allow the request only with a valid session cookie.

    Raises:
        HTTPException(401): the ``cfgui_session`` cookie is missing or fails
            ``_valid_session_cookie``.
    """
    session_cookie = request.cookies.get("cfgui_session")
    if _valid_session_cookie(session_cookie):
        return
    raise HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Config UI login required",
    )


def _require_csrf(request: Request):
    """
    FastAPI dependency enforcing the double-submit CSRF check.

    The ``cfgui_csrf`` cookie must be present and equal to the
    ``X-CSRF-Token`` request header.

    Raises:
        HTTPException(403): either token is missing or they differ.
    """
    token_cookie = request.cookies.get("cfgui_csrf")
    token_hdr = request.headers.get("X-CSRF-Token")
    # compare_digest keeps the comparison constant-time; comparing UTF-8
    # bytes stops it raising TypeError on non-ASCII header values.
    if (
        not token_cookie
        or not token_hdr
        or not hmac.compare_digest(token_cookie.encode(), token_hdr.encode())
    ):
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN, detail="Invalid CSRF token"
        )


def _parse_forwarded_element(element: str) -> dict[str, str]:
    """
    Parse one comma-separated element of an RFC 7239 ``Forwarded`` header.

    Returns ``{name: value}`` with names lower-cased (RFC 7239 parameter
    names are case-insensitive) and whitespace plus surrounding double
    quotes stripped from values. Pairs with an empty name or value are
    skipped; a repeated name keeps its last value.
    """
    params: dict[str, str] = {}
    for pair in element.split(";"):
        name, _, value = pair.partition("=")
        name = name.strip().lower()
        value = value.strip()
        if name and value:
            params[name] = value.strip('"')
    return params


def _expected_origins(request: Request) -> set[str]:
    """
    Build every ``scheme://host`` origin this request may legitimately carry.

    Schemes: the request URL scheme plus any ``X-Forwarded-Proto`` and
    ``Forwarded: proto=`` values. Hosts: ``Host`` plus any
    ``X-Forwarded-Host`` and ``Forwarded: host=`` values. Every scheme is
    paired with every host.
    """
    headers = request.headers
    schemes: set[str] = {request.url.scheme}
    hosts: set[str] = set()

    host = headers.get("host")
    if host:
        hosts.add(host)

    xf_proto = headers.get("x-forwarded-proto")
    if xf_proto:
        schemes.update(p.strip() for p in xf_proto.split(",") if p.strip())

    xf_host = headers.get("x-forwarded-host")
    if xf_host:
        hosts.update(h.strip() for h in xf_host.split(",") if h.strip())

    forwarded = headers.get("forwarded")
    if forwarded:
        for element in forwarded.split(","):
            params = _parse_forwarded_element(element)
            if "proto" in params:
                schemes.add(params["proto"])
            if "host" in params:
                hosts.add(params["host"])

    # Empty parts can never match a real Origin, so drop those pairs.
    return {f"{s}://{h}" for s in schemes for h in hosts if s and h}


def _require_same_origin(request: Request):
    """
    Enforce same-origin on state-changing endpoints to complement CSRF.

    The ``Origin`` header must be present and equal to one of the origins
    from ``_expected_origins``. Otherwise a warning with the relevant
    headers is logged and 403 is raised.

    NOTE(review): forwarding headers are taken at face value; that is only
    as trustworthy as the reverse proxy in front of the app. Confirm that it
    overwrites or strips client-supplied values.

    Raises:
        HTTPException(403): Origin is missing or not in the expected set.
    """
    headers = request.headers
    # "-" placeholders are for diagnostics only (logs / error detail).
    hdr_host = headers.get("host") or "-"
    hdr_xf_host = headers.get("x-forwarded-host") or "-"
    hdr_xf_proto = headers.get("x-forwarded-proto") or "-"
    hdr_forwarded = headers.get("forwarded")

    origin = headers.get("origin")
    if not origin:
        # Origin is mandatory here; non-browser clients must send it too.
        log.warning(
            "Config UI same-origin check failed: missing Origin "
            "(host=%s x-forwarded-host=%s x-forwarded-proto=%s forwarded=%s)",
            hdr_host,
            hdr_xf_host,
            hdr_xf_proto,
            hdr_forwarded,
        )
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail=(
                "Missing Origin header "
                f"(host={hdr_host}, "
                f"x-forwarded-host={hdr_xf_host}, "
                f"x-forwarded-proto={hdr_xf_proto})"
            ),
        )

    expected = _expected_origins(request)
    if origin not in expected:
        log.warning(
            "Config UI same-origin check failed: origin=%s expected=%s "
            "(request_scheme=%s host=%s x-forwarded-host=%s "
            "x-forwarded-proto=%s forwarded=%s)",
            origin,
            sorted(expected),
            request.url.scheme,
            hdr_host,
            hdr_xf_host,
            hdr_xf_proto,
            hdr_forwarded,
        )
        expected_joined = ",".join(sorted(expected)) or "-"
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail=(
                "Cross-origin not allowed "
                f"(origin={origin}, expected={expected_joined}, "
                f"host={hdr_host}, "
                f"x-forwarded-host={hdr_xf_host}, "
                f"x-forwarded-proto={hdr_xf_proto})"
            ),
        )


# ───────────── helpers ─────────────
def _json_pretty(d: dict[str, Any]) -> str:
    """Serialize *d* as stable JSON: 2-space indent, sorted keys, non-ASCII kept."""
    encoder = json.JSONEncoder(ensure_ascii=False, indent=2, sort_keys=True)
    return encoder.encode(d)


def _compute_diff(old_cfg: dict[str, Any], new_cfg: dict[str, Any]) -> str:
    """
    Return a unified diff (``old`` -> ``new``) of the pretty-printed configs.

    Returns an empty string when both render identically.
    """

    def _lines(cfg: dict[str, Any]) -> list[str]:
        # Split on "\n" only. json.dumps escapes "\n"/"\r" inside strings,
        # but with ensure_ascii=False it leaves U+2028/U+0085 raw, and
        # str.splitlines would break lines on those. Terminating every line,
        # including the last one, keeps unified_diff from gluing "-}" and "+}"
        # together when the final line differs (e.g. "{}" vs a non-empty dict).
        return [line + "\n" for line in _json_pretty(cfg).split("\n")]

    diff = difflib.unified_diff(
        _lines(old_cfg), _lines(new_cfg), fromfile="old", tofile="new"
    )
    return "".join(diff)


def _reload_schema_in_process() -> bool:
    """Rebuild the GraphQL schema in this process; return True on success."""
    try:
        from gql_app import reload_schema

        reload_schema()
    except Exception:
        log.warning("In-process GraphQL schema reload failed", exc_info=True)
        return False
    return True


def _graceful_reload() -> None:
    """
    Prefer a process reload via SIGHUP when running with multiple uvicorn
    workers; fall back to rebuilding the GraphQL schema in-process when
    only a single worker is configured.

    Order of attempts:
      1. ``WORKERS`` <= 1 (default "1"): in-process ``reload_schema()``.
      2. Otherwise, or if (1) failed: SIGHUP to the parent process.
         NOTE(review): this assumes the parent is the server master that
         reloads on SIGHUP; confirm for the deployment's process manager.
      3. If signalling is not possible: in-process ``reload_schema()``.

    Best-effort: failures are logged, never raised, so callers' save or
    activate operations are not reported as failed because of the reload.
    """
    raw_workers = os.getenv("WORKERS", "1")
    try:
        workers = int(raw_workers)
    except ValueError:
        log.warning("Invalid WORKERS=%r; assuming a single worker", raw_workers)
        workers = 1

    if workers <= 1 and _reload_schema_in_process():
        return

    ppid = os.getppid()
    # getppid() returns 0 when the parent lives in another PID namespace
    # (e.g. this process is PID 1 in a container). os.kill(0, ...) would then
    # signal this process's whole group, including itself, so skip it.
    if ppid > 0:
        try:
            os.kill(ppid, signal.SIGHUP)
            return
        except Exception:  # e.g. no SIGHUP on Windows, or EPERM/ESRCH
            log.warning(
                "Sending SIGHUP to parent pid %s failed", ppid, exc_info=True
            )
    else:
        log.warning("No parent process to signal for reload")

    _reload_schema_in_process()


# ───────────────────────────── ROUTES ───────────────────────────────────────
# All /config routes now require a valid Config UI session.
# Mutations additionally require CSRF and same-origin.


@router.get(
    "",
    response_class=JSONResponse,
    dependencies=[Depends(_require_cfg_session)],
)
async def get_config(
    version: str | None = Query(default=None),
):  # pragma: no cover - covered via integration
    """
    Return the requested version (or the active one) plus a *varsResolved*
    field and a *localsEnabled* flag; the ``locals`` section is removed
    from the response when locals are disabled.
    """
    # Dict display evaluates left to right: raw config, resolved vars, then
    # the locals flag, which is the same call order as before.
    snapshot: dict[str, Any] = {
        **ec.get_raw_config(version=version),
        "varsResolved": ec.get_resolved_vars(),
        "localsEnabled": ec.locals_enabled(),
    }
    locals_on = snapshot["localsEnabled"]
    if not locals_on:
        snapshot.pop("locals", None)
    return snapshot
@router.get(
    "/schema",
    response_class=JSONResponse,
    dependencies=[Depends(_require_cfg_session)],
)
async def get_schema():  # pragma: no cover - covered via integration
    """Return the JSON schema published by ``endpoint_config``."""
    schema = ec.get_json_schema()
    return schema
@router.get(
    "/versions",
    response_class=JSONResponse,
    dependencies=[Depends(_require_cfg_session)],
)
async def get_versions():  # pragma: no cover - covered via integration
    """
    List every saved version together with the active version, the version
    currently loaded, and whether the loader is running in recovery mode.
    """
    catalogue = ec.get_versions()
    return catalogue
@router.post(
    "/activate/{version}",
    response_class=JSONResponse,
    dependencies=[
        Depends(_require_cfg_session),
        Depends(_require_csrf),
        Depends(_require_same_origin),
    ],
)
async def activate_version(
    version: str,
):  # pragma: no cover - covered in higher-level tests
    """
    Activate an existing version, then trigger a graceful reload.

    Any failure is reported as HTTP 400 carrying the exception text.
    """
    try:
        ec.activate_version(version)
        _graceful_reload()
    except Exception as exc:
        raise HTTPException(status_code=400, detail=str(exc))
    return {"ok": True, "active": version}
@router.delete(
    "/version/{version}",
    response_class=JSONResponse,
    dependencies=[
        Depends(_require_cfg_session),
        Depends(_require_csrf),
        Depends(_require_same_origin),
    ],
)
async def delete_version(
    version: str,
):  # pragma: no cover - covered in higher-level tests
    """
    Permanently delete a non-active version from the store, then trigger a
    graceful reload.

    Any failure is reported as HTTP 400 carrying the exception text.
    """
    try:
        ec.delete_version(version)
        _graceful_reload()
    except Exception as exc:
        raise HTTPException(status_code=400, detail=str(exc))
    return {"ok": True, "deleted": version}
@router.put(
    "",
    response_class=JSONResponse,
    dependencies=[
        Depends(_require_cfg_session),
        Depends(_require_csrf),
        Depends(_require_same_origin),
    ],
)
async def save_config(
    payload: dict[str, Any],
):  # pragma: no cover - covered in higher-level tests
    """
    Validate and save a new version to MongoDB, returning the diff against
    the previous active version, and trigger a graceful reload when
    `activate` is true.

    Request shape (both supported):
      1) Back-compat: raw endpoints.json object (auto-bumps patch & activates)
      2) New:
         {
           "config": <endpoints.json>,
           "version": "1.2.4",      # optional, defaults to auto-bump patch
           "activate": true,        # optional, defaults to true
           "message": "why"         # optional
         }

    Returns ``{"ok": true, "diff": "<unified diff>"}``. Any failure
    (validation or storage) is reported as HTTP 400.
    """
    try:
        envelope = payload.get("config")
        if isinstance(envelope, dict):
            # New shape: options travel alongside the config.
            new_cfg = envelope
            requested_ver = payload.get("version")
            activate = bool(payload.get("activate", True))
            message = payload.get("message")
        else:
            # Legacy shape: the payload itself is the config.
            new_cfg, requested_ver, activate, message = payload, None, True, None

        validate_payload(new_cfg)

        # Snapshot what is live *before* saving so the diff is taken against
        # the current active version, not the requested target.
        previous = ec.get_raw_config()
        ec.save_raw_config(
            new_cfg, version=requested_ver, activate=activate, message=message
        )
        diff = _compute_diff(previous, new_cfg)

        if activate:
            _graceful_reload()
        return {"ok": True, "diff": diff}
    except (ValidationError, Exception) as exc:
        raise HTTPException(status_code=400, detail=str(exc))
@router.post(
    "/diff",
    response_class=JSONResponse,
    dependencies=[
        Depends(_require_cfg_session),
        Depends(_require_csrf),
        Depends(_require_same_origin),
    ],
)
async def diff_payload(  # pragma: no cover - covered via API tests
    payload: dict[str, Any] = Body(...),
):
    """
    Compute a unified diff between two payloads:
      - base: requested version (when provided) or the active version
      - new : the submitted payload (raw endpoints JSON)

    Request body accepts either:
      { "config": <json>, "baseVersion": "1.2.3" }   # preferred
      <json>                                         # legacy shape (no baseVersion)

    Returns: { ok: true, diff: "<unified diff>" }
    """
    try:
        embedded = payload.get("config")
        if isinstance(embedded, dict):
            new_cfg, base_version = embedded, payload.get("baseVersion")
        else:
            new_cfg, base_version = payload, None
        # Preview only: the candidate is deliberately not validated here.
        baseline = ec.get_raw_config(version=base_version)
        return {"ok": True, "diff": _compute_diff(baseline, new_cfg)}
    except Exception as exc:
        raise HTTPException(status_code=400, detail=str(exc))
@router.get(
    "/diff",
    response_class=JSONResponse,
    dependencies=[Depends(_require_cfg_session)],
)
async def diff_versions(  # pragma: no cover - covered via API tests
    frm: str | None = Query(default=None, alias="from"),
    to: str | None = Query(default=None, alias="to"),
):
    """
    Compute a unified diff between two SAVED versions.

    Query params: ?from=1.2.3&to=1.2.4
    An omitted side is passed to ``ec.get_raw_config`` as ``version=None``,
    which the get_config route treats as the active version.

    Returns: { ok: true, diff: "<unified diff>" }
    """
    try:
        # Arguments are evaluated left to right: "from" side first, then "to".
        diff = _compute_diff(
            ec.get_raw_config(version=frm),
            ec.get_raw_config(version=to),
        )
    except Exception as exc:
        raise HTTPException(status_code=400, detail=str(exc))
    return {"ok": True, "diff": diff}
@router.post(
    "/test-local/{func_name}",
    response_class=JSONResponse,
    dependencies=[
        Depends(_require_cfg_session),
        Depends(_require_csrf),
        Depends(_require_same_origin),
    ],
)
async def test_local(
    func_name: str, payload: dict[str, Any] = Body(default={})
):  # pragma: no cover - covered via API tests
    """
    Execute a local function declared in endpoints.json; return raw output.

    ``func_name`` must be a key of the active config's ``locals`` section.
    The callable is obtained via ``ec._import(func_name)`` and called with
    ``payload`` as keyword arguments. If the call returns an awaitable, it is
    awaited, so both sync and async functions work.

    Raises:
        HTTPException(404): ``func_name`` is not declared under ``locals``.
        HTTPException(500): importing or calling the function failed; the
            detail is ``"<ExceptionType>: <message>"``.
    """
    try:
        cfg = ec.get_raw_config()
        # NOTE(review): unlike GET /config, this does not consult
        # ec.locals_enabled(); confirm whether execution should be refused
        # when locals are disabled.
        if func_name not in (cfg.get("locals") or {}):
            raise HTTPException(
                status_code=404,
                detail=f"Local function '{func_name}' not found",
            )
        fn = ec._import(func_name)
        result = fn(**(payload or {}))
        # Await whatever comes back instead of pre-checking with
        # asyncio.iscoroutinefunction. That check misses sync wrappers and
        # callable objects that return coroutines (the un-awaited coroutine
        # would then be returned as the result), and it is deprecated in
        # Python 3.14.
        if inspect.isawaitable(result):
            result = await result
        return {"ok": True, "result": result}
    except HTTPException:
        raise
    except Exception as exc:
        raise HTTPException(
            status_code=500, detail=f"{type(exc).__name__}: {exc}"
        ) from exc