Implement logout and token blocklisting to revoke JWT tokens
JWT tokens are stateless by design, meaning they remain valid until they expire. Token revocation allows you to invalidate tokens before their expiration, which is essential for implementing logout functionality and handling compromised tokens.
# Simple set for development (lost on restart)
TOKEN_BLOCKLIST: set[str] = set()


def check_if_token_revoked(token: str) -> bool:
    """Check if a token is in the blocklist.

    Args:
        token: The raw JWT token string

    Returns:
        True if the token is revoked, False otherwise
    """
    # Plain membership test against the in-memory set defined above.
    return token in TOKEN_BLOCKLIST
Performance tip: Use Redis or another fast key-value store for production. Database queries add latency to every authenticated request.
from fastapi import FastAPI
from authx import AuthX, AuthXConfig

app = FastAPI()

# NOTE: "your-secret-key" is a placeholder value for this example.
auth_config = AuthXConfig(
    JWT_ALGORITHM="HS256",
    JWT_SECRET_KEY="your-secret-key",
    JWT_TOKEN_LOCATION=["headers"],
)
auth = AuthX(config=auth_config)
auth.handle_errors(app)

# In-memory storage for revoked tokens (raw token strings).
TOKEN_BLOCKLIST: set[str] = set()


def check_if_token_revoked(token: str) -> bool:
    """Tell AuthX whether the given token has been revoked.

    AuthX invokes this callback for every protected request.

    Args:
        token: The raw JWT token string.

    Returns:
        True when the token is in the blocklist and must be rejected,
        False otherwise.
    """
    is_revoked = token in TOKEN_BLOCKLIST
    return is_revoked


# Hook the callback into AuthX so it is consulted during verification.
auth.set_token_blocklist(check_if_token_revoked)
The callback receives the raw token string. You can decode it to access claims like JTI if needed, but storing the full token is simpler for blocklisting.
Create a logout endpoint that revokes the current token:
from fastapi import Request, HTTPException


@app.post("/logout")
async def logout(request: Request):
    """Logout endpoint that revokes the current token.

    Args:
        request: The incoming request carrying the access token.

    Returns:
        A JSON-serialisable dict with a confirmation message.
    """
    # Get the token from the request
    token = await auth.get_access_token_from_request(request)

    # Verify it's valid before revoking. Only the side effect matters here:
    # an invalid token is expected to raise, so nothing gets blocklisted.
    auth.verify_token(token)

    # Add the raw token string to the blocklist
    TOKEN_BLOCKLIST.add(token.token)

    return {"message": "Successfully logged out"}
Once a token is in the blocklist, AuthX automatically rejects it:
from authx.schema import TokenPayload


@app.get("/protected")
async def protected_route(payload: TokenPayload = auth.ACCESS_REQUIRED):
    """Serve a protected resource; token revocation is checked automatically.

    Args:
        payload: The verified token payload injected by AuthX.

    Returns:
        A dict with a confirmation message and the token's subject.
    """
    # A blocklisted token never gets this far: AuthX raises
    # RevokedTokenError (401) before the route body is entered.
    response = {
        "message": "Access granted",
        "username": payload.sub,
    }
    return response
# Issued tokens, grouped by username
USER_TOKENS: dict[str, set[str]] = {}


@app.post("/login")
def login(data: LoginRequest):
    """Issue an access token and remember it under the user's name.

    Args:
        data: Login request; its ``username`` becomes the token's uid.

    Returns:
        A dict holding the access token and the token type.
    """
    access_token = auth.create_access_token(uid=data.username)

    # Create the user's token set on first login, then record the token.
    USER_TOKENS.setdefault(data.username, set()).add(access_token)

    return {"access_token": access_token, "token_type": "bearer"}


@app.post("/logout-all")
async def logout_all_devices(payload: TokenPayload = auth.ACCESS_REQUIRED):
    """Revoke every tracked token of the calling user.

    Args:
        payload: The verified token payload injected by AuthX; its ``sub``
            identifies the user.

    Returns:
        A dict with a confirmation message.
    """
    username = payload.sub

    issued = USER_TOKENS.get(username)
    if issued is not None:
        # Blocklist everything issued to this user, then forget the tokens.
        TOKEN_BLOCKLIST.update(issued)
        issued.clear()

    return {"message": "Logged out from all devices"}
Store only the JTI (JWT ID) for more efficient storage:
from authx.token import decode_token

# Store JTIs instead of full tokens
REVOKED_JTIS: set[str] = set()


def check_if_token_revoked(token: str) -> bool:
    """Check if a token's JTI is revoked.

    Args:
        token: The raw JWT token string.

    Returns:
        True if the token carries a JTI that is in ``REVOKED_JTIS``;
        False if it has no JTI, the JTI is not revoked, or the token
        cannot be decoded.
    """
    # Decode token to get JTI (without verification for blocklist check)
    try:
        payload = decode_token(token, key=auth.config.public_key, verify=False)
    except Exception:
        # Best effort: an undecodable token is simply "not in the blocklist".
        # Catching Exception (not a bare except) lets KeyboardInterrupt and
        # SystemExit propagate.
        return False

    jti = payload.get("jti")
    # A missing JTI must never match, otherwise one JTI-less entry would
    # make every JTI-less token look revoked.
    return jti is not None and jti in REVOKED_JTIS


# Register the callback
auth.set_token_blocklist(check_if_token_revoked)


@app.post("/logout")
async def logout(request: Request):
    """Logout by revoking the token's JTI.

    Args:
        request: The incoming request carrying the access token.

    Returns:
        A dict with a confirmation message.
    """
    token = await auth.get_access_token_from_request(request)
    payload = auth.verify_token(token)

    # Add JTI to revoked set; skip a missing JTI so None never enters the set.
    if payload.jti is not None:
        REVOKED_JTIS.add(payload.jti)

    return {"message": "Successfully logged out"}
Using JTI is more efficient as it stores only a small identifier instead of the full token string. However, ensure your JTIs are properly generated (AuthX does this automatically using UUIDs).
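Remove expired tokens from the blocklist to save space. An expired token is rejected by normal expiry validation anyway, so its blocklist entry no longer serves a purpose: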
import asyncio
import time
from datetime import datetime
from typing import Dict

# Store tokens with expiration times
BLOCKLIST: Dict[str, float] = {}  # token -> expiration timestamp (epoch seconds)


def check_if_token_revoked(token: str) -> bool:
    """Check if token is revoked and not expired.

    Args:
        token: The raw JWT token string.

    Returns:
        True if the token is in the blocklist and its recorded expiration
        has not passed yet, False otherwise.
    """
    expires_at = BLOCKLIST.get(token)
    if expires_at is None:
        return False

    if expires_at < time.time():
        # Token expiration passed, so the entry is no longer needed.
        # pop() with a default tolerates an entry that is already gone.
        BLOCKLIST.pop(token, None)
        return False

    return True


def revoke_token_with_expiry(token: str, payload: "TokenPayload") -> None:
    """Add token to blocklist with its expiration time.

    Args:
        token: The raw JWT token string.
        payload: Token payload whose ``exp`` attribute gives the expiration.
            A numeric ``exp`` is stored as-is, a ``datetime`` is converted to
            an epoch timestamp, and ``None`` means the entry never expires.
    """
    exp = payload.exp
    if exp is None:
        # No expiry claim: keep the token revoked indefinitely rather than
        # failing on a None comparison at check time.
        expires_at = float("inf")
    elif isinstance(exp, datetime):
        # NOTE: datetime.timestamp() treats a naive datetime as local time.
        expires_at = exp.timestamp()
    else:
        expires_at = exp
    BLOCKLIST[token] = expires_at


def _purge_expired_entries() -> None:
    """Drop every blocklist entry whose expiration time has passed."""
    now = time.time()
    # Collect first so the dict is not mutated while being iterated.
    expired = [tok for tok, exp in BLOCKLIST.items() if exp < now]
    for tok in expired:
        BLOCKLIST.pop(tok, None)


# Periodic cleanup task
async def cleanup_blocklist(interval_seconds: float = 3600) -> None:
    """Remove expired entries from blocklist.

    Runs forever; each cycle sleeps first and then purges.

    Args:
        interval_seconds: Pause between cleanup passes. Defaults to one hour.
    """
    while True:
        await asyncio.sleep(interval_seconds)
        _purge_expired_entries()