"""Authentication setup: OAuth providers, JWT auth backend and user manager."""

import logging
from datetime import datetime, timezone
from typing import Optional

from fastapi import Depends, Request
from fastapi.security import OAuth2PasswordRequestForm
from fastapi_users import BaseUserManager, FastAPIUsers, IntegerIDMixin
from fastapi_users.authentication import (
    AuthenticationBackend,
    BearerTransport,
    JWTStrategy,
)
from fastapi_users.db import SQLAlchemyUserDatabase
from sqlalchemy.ext.asyncio import AsyncSession
from authlib.integrations.starlette_client import OAuth
from starlette.config import Config
from starlette.middleware.sessions import SessionMiddleware
from starlette.responses import Response
from pydantic import BaseModel
from fastapi.responses import JSONResponse
from sqlalchemy import select

from .database import get_session
from .models import User
from .config import settings

logger = logging.getLogger(__name__)

config = Config('.env')
oauth = OAuth(config)

oauth.register(
    name='google',
    server_metadata_url='https://accounts.google.com/.well-known/openid-configuration',
    client_kwargs={
        'scope': 'openid email profile',
        'redirect_uri': settings.GOOGLE_REDIRECT_URI,
    },
)

oauth.register(
    name='apple',
    server_metadata_url='https://appleid.apple.com/.well-known/openid-configuration',
    client_kwargs={
        'scope': 'openid email name',
        'redirect_uri': settings.APPLE_REDIRECT_URI,
    },
)

# Refresh tokens live for 7 days.
REFRESH_TOKEN_LIFETIME_SECONDS = 7 * 24 * 60 * 60

# Audience claim written into (and required from) refresh tokens. It
# deliberately differs from JWTStrategy's default access-token audience
# ("fastapi-users:auth"); otherwise a long-lived refresh token would be
# accepted anywhere a bearer access token is. Any code that validates refresh
# tokens must do so via get_refresh_jwt_strategy().
REFRESH_TOKEN_AUDIENCE = ["fastapi-users:refresh"]


class BearerResponseWithRefresh(BaseModel):
    """Login response body that carries both an access and a refresh token."""

    access_token: str
    refresh_token: str
    token_type: str = "bearer"


class BearerTransportWithRefresh(BearerTransport):
    """Bearer transport whose login response may include a refresh token."""

    async def get_login_response(
        self, token: str, refresh_token: Optional[str] = None
    ) -> Response:
        """Build the JSON login response.

        Args:
            token: The access token.
            refresh_token: Optional refresh token; included in the body only
                when it is truthy.

        Returns:
            A ``JSONResponse`` with ``access_token``, optionally
            ``refresh_token``, and ``token_type`` (always ``"bearer"``).
        """
        # Build the payload directly instead of going through
        # BaseModel.dict(), which is deprecated in Pydantic v2.
        payload = {"access_token": token}
        if refresh_token:
            payload["refresh_token"] = refresh_token
        payload["token_type"] = "bearer"
        return JSONResponse(payload)


class AuthenticationBackendWithRefresh(AuthenticationBackend):
    """Authentication backend that issues a refresh token on login."""

    def __init__(
        self,
        name: str,
        transport: BearerTransportWithRefresh,
        get_strategy,
        get_refresh_strategy,
    ):
        """Store the backend configuration.

        Args:
            name: Backend name (used in route paths by FastAPI Users).
            transport: Transport that renders login/logout responses.
            get_strategy: Callable returning the access-token strategy.
            get_refresh_strategy: Zero-argument callable returning the
                refresh-token strategy; called directly on each login.
        """
        super().__init__(name=name, transport=transport, get_strategy=get_strategy)
        self.get_refresh_strategy = get_refresh_strategy

    async def login(self, strategy, user) -> Response:
        """Issue an access token and a refresh token for *user*."""
        access_token = await strategy.write_token(user)
        refresh_strategy = self.get_refresh_strategy()
        refresh_token = await refresh_strategy.write_token(user)
        return await self.transport.get_login_response(
            token=access_token, refresh_token=refresh_token
        )

    async def logout(self, strategy, user, token) -> Response:
        """Log *user* out.

        Delegates to the base implementation: ``BearerTransport`` raises
        ``TransportLogoutNotSupportedError`` from ``get_logout_response`` and
        the base class converts that into a 204 response. Calling the
        transport directly would surface the exception as a server error.
        """
        return await super().logout(strategy, user, token)


class UserManager(IntegerIDMixin, BaseUserManager[User, int]):
    """User manager with integer IDs and soft-delete semantics."""

    reset_password_token_secret = settings.SECRET_KEY
    verification_token_secret = settings.SECRET_KEY

    async def on_after_register(self, user: User, request: Optional[Request] = None):
        logger.info("User %s has registered.", user.id)

    async def on_after_forgot_password(
        self, user: User, token: str, request: Optional[Request] = None
    ):
        logger.info("User %s has forgot their password.", user.id)
        # Tokens are credentials; only emit them at DEBUG level.
        logger.debug("Password reset token for user %s: %s", user.id, token)

    async def on_after_request_verify(
        self, user: User, token: str, request: Optional[Request] = None
    ):
        logger.info("Verification requested for user %s.", user.id)
        logger.debug("Verification token for user %s: %s", user.id, token)

    async def on_after_login(
        self,
        user: User,
        request: Optional[Request] = None,
        response: Optional[Response] = None,
    ):
        logger.info("User %s has logged in.", user.id)

    async def delete(self, user: User, safe: bool = False, request: Optional[Request] = None):
        """Soft-delete and anonymise *user* instead of removing the DB row.

        This mitigates catastrophic data-loss cascades that can occur when the
        user row is physically deleted (see TODO issue #3). The record is kept
        for referential integrity, while personally identifiable information
        is removed and the account is marked inactive.

        Args:
            user: The user to soft-delete.
            safe: Accepted for backward compatibility; currently ignored.
            request: Optional originating request, forwarded to the
                ``on_before_delete`` / ``on_after_delete`` hooks.
        """
        await self.on_before_delete(user, request)

        # One timestamp for both the anonymised email and deleted_at.
        now = datetime.now(timezone.utc)
        anonymised_suffix = f"deleted_{user.id}_{int(now.timestamp())}"

        # SQLAlchemyUserDatabase.update() requires the dict of fields to set;
        # it applies them to the instance and commits the session.
        await self.user_db.update(
            user,
            {
                # Unique but meaningless email address.
                "email": f"user_{anonymised_suffix}@example.com",
                "name": None,
                "hashed_password": "",
                "is_active": False,
                "is_verified": False,
                "deleted_at": now,
                "is_deleted": True,
            },
        )

        # No hard delete is issued, so FK references (expenses, lists, etc.)
        # remain intact.
        await self.on_after_delete(user, request)
        return None


async def get_user_db(session: AsyncSession = Depends(get_session)):
    yield SQLAlchemyUserDatabase(session, User)


async def get_user_manager(user_db: SQLAlchemyUserDatabase = Depends(get_user_db)):
    yield UserManager(user_db)


bearer_transport = BearerTransportWithRefresh(tokenUrl="/api/v1/auth/jwt/login")


def get_jwt_strategy() -> JWTStrategy:
    """Return the access-token strategy (default FastAPI Users audience)."""
    return JWTStrategy(
        secret=settings.SECRET_KEY,
        lifetime_seconds=settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60,
    )


def get_refresh_jwt_strategy() -> JWTStrategy:
    """Return the refresh-token strategy.

    Uses a dedicated audience so refresh tokens are rejected by the
    access-token strategy (and vice versa).
    """
    return JWTStrategy(
        secret=settings.SECRET_KEY,
        lifetime_seconds=REFRESH_TOKEN_LIFETIME_SECONDS,
        token_audience=REFRESH_TOKEN_AUDIENCE,
    )


auth_backend = AuthenticationBackendWithRefresh(
    name="jwt",
    transport=bearer_transport,
    get_strategy=get_jwt_strategy,
    get_refresh_strategy=get_refresh_jwt_strategy,
)

fastapi_users = FastAPIUsers[User, int](
    get_user_manager,
    [auth_backend],
)

current_active_user = fastapi_users.current_user(active=True)
current_superuser = fastapi_users.current_user(active=True, superuser=True)


# ---------------------------------------------------------------------------
# JWT helper function used by WebSocket endpoints
# ---------------------------------------------------------------------------
async def get_user_from_token(token: str, db: AsyncSession) -> Optional[User]:
    """Return the ``User`` associated with a valid access *token*, or ``None``.

    The token is validated with the same access-token strategy FastAPI Users
    uses elsewhere in the application. ``None`` is returned when the token is
    missing/invalid/expired, its subject is malformed, the user does not
    exist, or the user is deleted/inactive, so the caller can close the
    WebSocket with the appropriate code.

    Args:
        token: The raw JWT access token.
        db: Async session used to look up the user.
    """
    strategy = get_jwt_strategy()
    user_manager = UserManager(SQLAlchemyUserDatabase(db, User))

    # JWTStrategy.read_token needs a user manager and returns the resolved
    # user itself; it already maps decode/expiry errors, malformed IDs and
    # unknown users to None.
    user = await strategy.read_token(token, user_manager)

    if user is None or getattr(user, "is_deleted", False) or not user.is_active:
        return None
    return user