import asyncio
import logging

from fastapi import APIRouter, Depends, Query, WebSocket, WebSocketDisconnect, status
from sqlalchemy.ext.asyncio import AsyncSession

from app.auth import get_jwt_strategy, get_user_from_token
from app.core.redis import subscribe_to_channel, unsubscribe_from_channel
from app.crud import list as crud_list
from app.database import get_transactional_session
from app.models import User

logger = logging.getLogger(__name__)

router = APIRouter()


async def _verify_jwt(token: str):
    """Return the user id encoded in *token*, or ``None`` if it is invalid/expired.

    Any exception raised while reading the token is treated as "not authenticated".
    """
    strategy = get_jwt_strategy()
    try:
        # FastAPI Users' JWTStrategy.read_token returns the user ID encoded in the token
        user_id = await strategy.read_token(token)
        return user_id
    except Exception:  # pragma: no cover – any decoding/expiration error
        return None


async def _forward_pubsub_messages(websocket: WebSocket, pubsub) -> None:
    """Relay every Redis pub/sub "message" payload from *pubsub* to *websocket* as text.

    Runs until it is cancelled or until reading from Redis / sending to the client raises.
    """
    while True:
        # timeout=None asks the pub/sub object to wait for the next message
        # rather than poll.
        message = await pubsub.get_message(ignore_subscribe_messages=True, timeout=None)
        if not message or message.get("type") != "message":
            continue
        data = message["data"]
        # Payloads arrive as bytes unless the Redis client decodes responses
        # (not visible here); send_text() requires str.
        if isinstance(data, (bytes, bytearray)):
            data = data.decode("utf-8")
        await websocket.send_text(data)


async def _wait_for_disconnect(websocket: WebSocket) -> None:
    """Read and discard client frames until the client disconnects, then return.

    Starlette reports a disconnect only through receive(). Without a reader, a
    client that leaves a quiet channel would never be noticed.
    """
    while True:
        message = await websocket.receive()
        if message["type"] == "websocket.disconnect":
            return


@router.websocket("/ws/lists/{list_id}")
async def list_websocket_endpoint(
    websocket: WebSocket,
    list_id: int,
    token: str = Query(...),
    db: AsyncSession = Depends(get_transactional_session),
):
    """Authenticated WebSocket endpoint that streams updates for a specific list.

    The connection is closed with 1008 (policy violation) when *token* does not
    resolve to a user, or when ``crud_list.check_list_permission`` raises for
    that user. Otherwise every message published on the Redis channel
    ``list_<list_id>`` is forwarded to the client as a text frame. This continues
    until the client disconnects or forwarding fails. The Redis subscription is
    always released on exit.
    """
    # NOTE(review): `db` comes from a dependency, so the session presumably
    # stays open for the whole lifetime of the socket even though it is only
    # used for the checks below. Confirm this does not pin a DB connection or
    # transaction for long-lived clients.
    user = await get_user_from_token(token, db)
    if not user:
        await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
        return

    try:
        # Verify the user has access to this list's group
        await crud_list.check_list_permission(db, list_id, user.id)
    except Exception:
        # Both "access denied" and genuine lookup failures land here. Log the
        # traceback so real errors are not hidden behind a rejected socket.
        logger.info(
            "Rejecting WebSocket for list %s (user %s): permission check failed",
            list_id,
            user.id,
            exc_info=True,
        )
        # 1008 (policy violation) is the RFC 6455 code for authorization
        # failures; 1003 would mean "unsupported data type".
        await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
        return

    await websocket.accept()

    # Subscribe to the list-specific channel
    channel = f"list_{list_id}"
    pubsub = await subscribe_to_channel(channel)
    forward_task = asyncio.create_task(_forward_pubsub_messages(websocket, pubsub))
    disconnect_task = asyncio.create_task(_wait_for_disconnect(websocket))
    try:
        # Whichever finishes first ends the session: either the client went
        # away or forwarding failed.
        done, _pending = await asyncio.wait(
            {forward_task, disconnect_task},
            return_when=asyncio.FIRST_COMPLETED,
        )
        for task in done:
            # Surface unexpected errors (e.g. Redis failures); a client
            # disconnect is the normal way out and is handled below.
            task.result()
    except WebSocketDisconnect:
        # Client disconnected
        pass
    finally:
        for task in (forward_task, disconnect_task):
            task.cancel()
        # Let the cancelled tasks finish before the subscription is torn down.
        await asyncio.gather(forward_task, disconnect_task, return_exceptions=True)
        # Clean up the Redis subscription
        await unsubscribe_from_channel(channel, pubsub)


@router.websocket("/ws/{household_id}")
async def household_websocket_endpoint(
    websocket: WebSocket,
    household_id: int,
    token: str = Query(...),
    db: AsyncSession = Depends(get_transactional_session),
):
    """Authenticated WebSocket endpoint for a household.

    Closes with 1008 (policy violation) when *token* does not resolve to a user;
    otherwise accepts the connection.
    """
    user = await get_user_from_token(token, db)
    if not user:
        await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
        return

    # TODO: Add permission check for household
    # NOTE(review): until that check exists, any authenticated user can open a
    # socket for any household_id.
    await websocket.accept()
    # ... rest of household logic