from collections.abc import Callable, Iterable
from typing import Any
from uuid import UUID
from fastapi.params import Security
from .jwt import DEFAULT_ISSUERS, DEFAULT_PUBLIC_KEYS, KeysType
from .roles import Role
from .token_scheme import TokenScheme
[docs]
class SecurityRequires(Security):
    """``Security`` marker that additionally carries the accepted roles and app ids.

    The dependency, cache flag and scopes are handed to the ``Security`` base
    initializer (scopes as a tuple). The ``roles`` and ``app_ids`` iterables
    are frozen into ``frozenset`` attributes of the same names.
    """

    def __init__(
        self,
        dependency: Callable[..., Any] | None,
        *,
        scopes: Iterable[str],
        roles: Iterable[Role],
        app_ids: Iterable[str],
        use_cache: bool = True,
    ):
        scope_tuple = tuple(scopes)
        super().__init__(
            dependency=dependency,
            scopes=scope_tuple,
            use_cache=use_cache,
        )
        # NOTE(review): object.__setattr__ bypasses any __setattr__ guard on the
        # base class (newer FastAPI versions define it as a frozen dataclass) —
        # confirm against the pinned FastAPI version.
        for attr_name, values in (("roles", roles), ("app_ids", app_ids)):
            object.__setattr__(self, attr_name, frozenset(values))
def Requires(
    audience: str,
    roles: Iterable[Role | str],
    scopes: Iterable[str] = (),
    app_ids: Iterable[str | UUID] = (),
    keys: KeysType = DEFAULT_PUBLIC_KEYS,
    issuer: str | Iterable[str] = DEFAULT_ISSUERS,
) -> SecurityRequires:
    """Build a FastAPI security dependency that authenticates callers via a token.

    The inputs are normalized and validated, then used to configure a
    ``TokenScheme``. That scheme is wrapped in a ``SecurityRequires`` marker.

    Args:
        audience: Expected token audience; forwarded to ``TokenScheme``.
        roles: Roles the endpoint accepts. Each entry is converted with
            ``Role(...)``.
        scopes: Required scopes. Must be non-empty unless ``Role.APP2APP``
            is among ``roles``.
        app_ids: Application ids allowed to call the endpoint. Each is
            converted with ``str`` (UUIDs become their string form). Must be
            non-empty exactly when ``Role.APP2APP`` is among ``roles``.
        keys: Key material forwarded unchanged to ``TokenScheme``.
        issuer: Accepted issuer, or several issuers. A single string is
            wrapped in a one-element tuple.

    Returns:
        A ``SecurityRequires`` whose dependency is the configured
        ``TokenScheme``. Its ``roles`` and ``app_ids`` mirror the normalized
        inputs, and its scopes are the normalized scopes in sorted order.

    Raises:
        ValueError: ``app_ids`` given without ``Role.APP2APP``;
            ``Role.APP2APP`` given without ``app_ids``; or no scopes given
            when ``Role.APP2APP`` is not used.
    """
    if isinstance(issuer, str):
        # A str is itself an Iterable[str]; wrap it so a single issuer is
        # passed on as a one-element tuple.
        issuer = (issuer,)

    roles_set = frozenset(Role(r) for r in roles)
    scopes_set = frozenset(scopes)
    # str() gives UUID and str ids one uniform representation.
    app_ids_set = frozenset(map(str, app_ids))

    uses_app2app = Role.APP2APP in roles_set
    if app_ids_set and not uses_app2app:
        msg = "When setting app_ids, you must include Role.APP2APP in your roles."
        raise ValueError(msg)
    if uses_app2app and not app_ids_set:
        msg = (
            "When allowing Role.APP2APP, you must set 'app_ids' "
            "to say which specific apps may call this endpoint."
        )
        raise ValueError(msg)
    # Check the materialized set, not the raw argument: a generator/iterator
    # is always truthy and has already been consumed by frozenset() above, so
    # `not scopes` would let an empty generator slip through.
    if not uses_app2app and not scopes_set:
        msg = "You must set required 'scopes' unless using Role.APP2APP"
        raise ValueError(msg)

    auth_context_from_token = TokenScheme(
        audience=audience,
        roles=roles_set,
        scopes=scopes_set,
        app_ids=app_ids_set,
        keys=keys,
        issuers=issuer,
    )
    return SecurityRequires(
        auth_context_from_token,
        roles=roles_set,
        # Sorted so the scope order handed to FastAPI is deterministic. The
        # iteration order of a frozenset of str varies between interpreter
        # runs because of str hash randomization.
        scopes=sorted(scopes_set),
        app_ids=app_ids_set,
    )