from collections.abc import AsyncGenerator, Iterable
from typing import Annotated
from fastapi import Depends, Header
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from fastapi.security.base import SecurityBase
from starlette_context import request_cycle_context
from . import validation_rules as vr
from .auth_context import AuthContext
from .jwt import KeysType, decode_jwt
from .roles import Role
# Module-wide bearer-token extractor. FastAPI resolves it to the request's
# Authorization credentials; with auto_error=True it rejects requests whose
# header is missing or malformed instead of returning None. TokenScheme reuses
# its scheme_name and model so both describe the same security scheme.
access_token = HTTPBearer(
    auto_error=True,
    bearerFormat="Entra ID Access Token (JWT)",
    scheme_name="SKAO Entra ID Authentication",
    description="Access token authorization from Microsoft Entra ID.",
)
def get_auth_context(
    authorization: HTTPAuthorizationCredentials,
    issuers: Iterable[str],
    audience: str,
    keys: KeysType,
    trace: str | None = None,
) -> AuthContext:
    """
    Decode a bearer token and build an AuthContext from its claims.

    :param authorization: Bearer credentials; ``authorization.credentials``
        is the raw token handed to ``decode_jwt``.
    :param issuers: Accepted issuers, passed through to ``decode_jwt``.
    :param audience: Expected audience, passed through to ``decode_jwt``.
    :param keys: Verification keys, passed through to ``decode_jwt``.
    :param trace: Optional request trace ID; ``None`` is stored as ``""``.
    :return: AuthContext holding the caller's user id, groups, scopes,
        principals, roles, the decoded claims and the raw access token.
    """
    trace = trace if trace is not None else ""
    claims = decode_jwt(
        authorization.credentials, keys=keys, issuers=issuers, audience=audience
    )
    groups = frozenset(claims.get("groups", []))
    # Roles may arrive via the "roles" claim or via group membership; only
    # values that are actual Role members survive the intersection.
    tkn_roles = frozenset(claims.get("roles", [])).intersection(Role)
    grp_roles = groups.intersection(Role)
    # Every caller whose token decodes is granted Role.ANY unconditionally.
    roles = tkn_roles | grp_roles | frozenset({Role.ANY})
    # TODO: Revisit this when we move away from MS Entra...
    if scp := claims.get("scp"):  # User tokens have scopes
        user_id = claims["oid"]
        # "scp" is a space-delimited list. split() with no separator drops
        # leading/trailing/repeated whitespace, so no empty-string "scope"
        # can end up in the set (split(" ") would produce one).
        scopes = frozenset(scp.split())
    else:  # Client credentials aka app tokens do not.
        scopes = frozenset()
        # We treat the app id as our user_id for app2app calls.
        # Tokens with ver == "1.0" carry it in "appid"; others in "azp".
        user_id = claims["appid"] if claims.get("ver") == "1.0" else claims["azp"]
    return AuthContext(
        audience=claims["aud"],
        user_id=user_id,
        trace=trace,
        groups=groups,
        scopes=scopes,
        # The caller can act as itself or as any of its groups.
        principals=groups.union((user_id,)),
        roles=roles,
        token_claims=claims,
        access_token=authorization.credentials,
    )
class TokenScheme(SecurityBase):
    """
    Security scheme that generates an AuthContext from a bearer token
    and enforces basic authorisation rules.

    Used as a FastAPI dependency. Each request's token is decoded into an
    AuthContext and checked against the configured roles. It is then checked
    against the allowed app IDs (app2app tokens) or the required scopes
    (user tokens).
    """

    def __init__(
        self,
        *,
        audience: str,
        roles: frozenset[Role],
        scopes: frozenset[str],
        app_ids: frozenset[str],
        issuers: Iterable[str],
        keys: KeysType,
    ):
        """
        :param audience: Audience every token must be issued for.
        :param roles: Roles checked via ``vr.roles_contain`` against the
            caller's roles.
        :param scopes: Scopes required from user (non-app2app) tokens.
        :param app_ids: App IDs allowed to make app2app calls.
        :param issuers: Accepted token issuers.
        :param keys: Keys used to verify token signatures.
        """
        self.scopes = scopes
        self.roles = roles
        self.app_ids = app_ids
        self.audience = audience
        # issuers is reused on every request, so a one-shot iterator (e.g. a
        # generator) must be materialised here or it would be exhausted after
        # the first request. A plain str is kept as-is in case decode_jwt
        # treats a single issuer string specially.
        self.issuers = issuers if isinstance(issuers, str) else tuple(issuers)
        # Mirror the HTTPBearer scheme's name and model so this scheme is
        # described the same way as access_token.
        self.scheme_name = access_token.scheme_name
        self.model = access_token.model
        self.keys = keys

    async def __call__(
        self,
        authorization: Annotated[HTTPAuthorizationCredentials, Depends(access_token)],
        X_Request_ID: Annotated[str | None, Header()] = None,
    ) -> AsyncGenerator[AuthContext, None]:
        """
        Build the request's AuthContext, run the authorisation checks and
        yield the context to the endpoint.

        :param authorization: Bearer credentials resolved by ``access_token``.
        :param X_Request_ID: Optional ``X-Request-ID`` header (FastAPI maps the
            underscores to hyphens); stored as the AuthContext trace.
        """
        # https://starlette-context.readthedocs.io/en/latest/fastapi.html
        # Create a global request-scoped context object that
        # we rely on to populate auth context into log messages.
        # See audit_log_filter.py.
        ac = get_auth_context(
            authorization, self.issuers, self.audience, self.keys, trace=X_Request_ID
        )
        with request_cycle_context({"auth": ac}):
            vr.roles_contain(self.roles, ac.roles)
            if Role.APP2APP in ac.roles:
                # For app2app calls, we check for an exact match of allowed app IDs.
                vr.app_id_matches(ac.user_id, self.app_ids)
            else:
                # For user calls, we check the user has granted the necessary scopes.
                vr.contains(self.scopes, ac.scopes, label="scopes")
            # Yield inside the context so the "auth" entry stays available
            # while the endpoint runs.
            yield ac