Source code for ska_aaa_authhelpers.token_scheme

from collections.abc import AsyncGenerator, Iterable
from typing import Annotated

from fastapi import Depends, Header
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from fastapi.security.base import SecurityBase
from starlette_context import request_cycle_context

from . import validation_rules as vr
from .auth_context import AuthContext
from .jwt import KeysType, decode_jwt
from .roles import Role

# Shared FastAPI bearer-token extractor for Entra ID access tokens.
# TokenScheme depends on it to obtain the raw credentials and reuses its
# scheme_name/model. With auto_error enabled, FastAPI's HTTPBearer reports a
# missing or malformed Authorization header itself, rather than returning None.
access_token = HTTPBearer(
    auto_error=True,
    bearerFormat="Entra ID Access Token (JWT)",
    description="Access token authorization from Microsoft Entra ID.",
    scheme_name="SKAO Entra ID Authentication",
)


def get_auth_context(
    authorization: HTTPAuthorizationCredentials,
    issuers: Iterable[str],
    audience: str,
    keys: KeysType,
    trace: str | None = None,
) -> AuthContext:
    """Decode a bearer token and build the AuthContext describing its caller.

    Args:
        authorization: Bearer credentials for the request. The raw token is
            read from ``authorization.credentials``.
        issuers: Accepted token issuers, passed through to ``decode_jwt``.
        audience: Expected token audience, passed through to ``decode_jwt``.
        keys: Verification keys, passed through to ``decode_jwt``.
        trace: Request trace identifier. ``None`` is stored as ``""``.

    Returns:
        An AuthContext holding the caller's user id, groups, scopes,
        principals (groups plus the user id), roles, the decoded claims and
        the raw access token.

    Raises:
        KeyError: If a claim read unconditionally is missing: ``aud`` always,
            ``oid`` for user tokens (with ``scp``), and ``appid`` (version
            "1.0") or ``azp`` for app tokens.

    Any exception raised by ``decode_jwt`` for a token it rejects propagates
    unchanged.
    """
    trace = trace if trace is not None else ""
    claims = decode_jwt(authorization.credentials, keys=keys, issuers=issuers, audience=audience)
    # A claim that is present but null would make frozenset() raise TypeError;
    # treat it the same as an absent claim.
    groups = frozenset(claims.get("groups") or ())
    # Keep only entries that match a Role member; anything else is ignored.
    tkn_roles = frozenset(claims.get("roles") or ()).intersection(Role)
    grp_roles = groups.intersection(Role)
    # Every successfully decoded caller additionally holds Role.ANY.
    roles = tkn_roles | grp_roles | frozenset({Role.ANY})
    # TODO: Revisit this when we move away from MS Entra...
    if scp := claims.get("scp"):  # User tokens have scopes
        user_id = claims["oid"]
        # split() rather than split(" "): leading, trailing or repeated spaces
        # must not produce an empty-string "scope".
        scopes = frozenset(scp.split())
    else:  # Client credentials aka app tokens do not.
        scopes = frozenset()
        # We treat the app id as our user_id for app2app calls.
        # Version "1.0" tokens are read from "appid"; all others from "azp".
        user_id = claims["appid"] if claims.get("ver") == "1.0" else claims["azp"]
    return AuthContext(
        audience=claims["aud"],
        user_id=user_id,
        trace=trace,
        groups=groups,
        scopes=scopes,
        principals=groups.union((user_id,)),
        roles=roles,
        token_claims=claims,
        access_token=authorization.credentials,
    )


class TokenScheme(SecurityBase):
    """
    Security scheme that generates an AuthContext from a bearer token and
    enforces basic authorisation rules.

    An instance is meant to be used as a FastAPI dependency. Each call:

    1. decodes the request's bearer token into an AuthContext;
    2. checks the caller's roles;
    3. checks either the allowed app IDs (app2app calls) or the granted
       scopes (user calls).
    """

    def __init__(
        self,
        *,
        audience: str,
        roles: frozenset[Role],
        scopes: frozenset[str],
        app_ids: frozenset[str],
        issuers: Iterable[str],
        keys: KeysType,
    ):
        """
        Args:
            audience: Expected token audience.
            roles: Roles checked against the caller's roles via
                ``vr.roles_contain`` on every request.
            scopes: Scopes checked against a user token's scopes.
            app_ids: App IDs checked against the caller's id for app2app
                tokens.
            issuers: Accepted token issuers.
            keys: Keys used by ``decode_jwt`` to verify tokens.
        """
        self.scopes = scopes
        self.roles = roles
        self.app_ids = app_ids
        self.audience = audience
        # Materialise the issuers once. This object serves many requests, and
        # a one-shot iterator (e.g. a generator) would be exhausted by the
        # first one, leaving every later request with no accepted issuers.
        # A plain string is kept as-is so decode_jwt sees exactly what it did
        # before.
        self.issuers = issuers if isinstance(issuers, str) else tuple(issuers)
        # Mirror the HTTPBearer scheme's metadata. Presumably this is so
        # OpenAPI documents this dependency as that same bearer scheme.
        self.scheme_name = access_token.scheme_name
        self.model = access_token.model
        self.keys = keys

    async def __call__(
        self,
        authorization: Annotated[HTTPAuthorizationCredentials, Depends(access_token)],
        X_Request_ID: Annotated[str | None, Header()] = None,
    ) -> AsyncGenerator[AuthContext, None]:
        """Authenticate and authorise the current request.

        Args:
            authorization: Bearer credentials resolved by the ``access_token``
                dependency.
            X_Request_ID: Optional request-ID header, used as the context's
                trace. FastAPI reads it from ``X-Request-ID``, converting
                underscores to hyphens.

        Yields:
            The AuthContext, once all checks pass. It stays installed in the
            request-scoped context under the key ``"auth"`` while the
            dependency is active.
        """
        # https://starlette-context.readthedocs.io/en/latest/fastapi.html
        # Create a global request-scoped context object that
        # we rely on to populate auth context into log messages.
        # See audit_log_filter.py.
        ac = get_auth_context(
            authorization, self.issuers, self.audience, self.keys, trace=X_Request_ID
        )
        with request_cycle_context({"auth": ac}):
            # The checks run inside the context, presumably so that log records
            # emitted by a failing check still carry the caller's auth details.
            vr.roles_contain(self.roles, ac.roles)
            if Role.APP2APP in ac.roles:
                # For app2app calls, we check for an exact match of allowed app IDs.
                vr.app_id_matches(ac.user_id, self.app_ids)
            else:
                # For user calls, we check the user has granted the necessary scopes.
                vr.contains(self.scopes, ac.scopes, label="scopes")
            yield ac