import asyncio
import logging

from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Query, status, Depends
from sqlalchemy.ext.asyncio import AsyncSession

from app.auth import get_jwt_strategy, get_user_from_token
from app.database import get_transactional_session
from app.crud import list as crud_list
from app.core.redis import subscribe_to_channel, unsubscribe_from_channel
from app.models import User

router = APIRouter()

logger = logging.getLogger(__name__)

# Seconds to wait between keep-alive pings on the list WebSocket.
_PING_INTERVAL_SECONDS = 10


async def _verify_jwt(token: str):
    """Return the user ID read from a JWT, or None if reading it fails.

    Args:
        token: The encoded JWT to read.

    Returns:
        Whatever ``strategy.read_token(token)`` returns on success, or
        ``None`` when that call raises any exception.
    """
    strategy = get_jwt_strategy()
    try:
        # FastAPI Users' JWTStrategy.read_token returns the user ID encoded in the token
        return await strategy.read_token(token)
    except Exception:  # pragma: no cover – any decoding/expiration error
        # NOTE(review): this broad catch also hides programming errors (for
        # example a signature mismatch in read_token) — confirm read_token
        # accepts a single token argument. Logged so failures leave a trace.
        logger.debug("JWT verification failed", exc_info=True)
        return None


@router.websocket("/ws/lists/{list_id}")
async def list_websocket_endpoint(
    websocket: WebSocket,
    list_id: int,
    token: str = Query(...),
    db: AsyncSession = Depends(get_transactional_session),
) -> None:
    """Authenticated WebSocket endpoint for a specific list.

    The connection is closed with a policy-violation code, before being
    accepted, when the token does not resolve to a user or when the list
    permission check raises. Otherwise the socket is accepted, a "connected"
    event is sent, and a "ping" event is sent every
    ``_PING_INTERVAL_SECONDS`` seconds until a send fails or the client
    disconnects.

    Args:
        websocket: The incoming WebSocket connection.
        list_id: ID of the list taken from the URL path.
        token: Auth token supplied as a required query parameter.
        db: Database session provided by ``get_transactional_session``.
    """
    user = await get_user_from_token(token, db)
    if not user:
        await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
        return

    try:
        # Verify the user has access to this list's group
        await crud_list.check_list_permission(db, list_id, user.id)
    except Exception:
        # An authorization refusal is a policy violation (1008), not an
        # unsupported-data condition (1003). The cause is logged because the
        # broad catch would otherwise hide unrelated failures.
        logger.warning(
            "Permission check failed for list %s, user %s",
            list_id,
            user.id,
            exc_info=True,
        )
        await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
        return

    await websocket.accept()

    # Temporary: Test connection without Redis
    logger.info("WebSocket connected for list %s, user %s", list_id, user.id)
    try:
        # Send a test message
        await websocket.send_text(
            '{"event": "connected", "payload": {"message": "WebSocket connected successfully"}}'
        )

        # Keep connection alive with periodic pings.
        # NOTE(review): this loop never reads from the socket, so a client
        # disconnect is presumably only noticed when the next send fails —
        # confirm whether incoming messages need to be consumed here.
        while True:
            await asyncio.sleep(_PING_INTERVAL_SECONDS)
            await websocket.send_text('{"event": "ping", "payload": {}}')
    except WebSocketDisconnect:
        logger.info("WebSocket disconnected for list %s, user %s", list_id, user.id)
    except Exception:
        # Top-level boundary for this connection: record the failure with its
        # traceback and let the handler finish instead of propagating.
        logger.exception("WebSocket error for list %s, user %s", list_id, user.id)
@router.websocket("/ws/{household_id}")
async def household_websocket_endpoint(
    websocket: WebSocket,
    household_id: int,
    token: str = Query(...),
    db: AsyncSession = Depends(get_transactional_session)
):
    """Authenticated WebSocket endpoint for a household.

    Resolves the caller from ``token``; the socket is accepted when a user is
    found and closed with a policy-violation code otherwise.

    Args:
        websocket: The incoming WebSocket connection.
        household_id: ID of the household taken from the URL path.
        token: Auth token supplied as a required query parameter.
        db: Database session provided by ``get_transactional_session``.
    """
    authenticated_user = await get_user_from_token(token, db)
    if authenticated_user:
        # TODO: Add permission check for household
        # NOTE(review): until that check exists, household_id is not used to
        # restrict which authenticated users may connect.
        await websocket.accept()
        # ... rest of household logic
    else:
        # No user could be resolved from the token: refuse the connection.
        await websocket.close(code=status.WS_1008_POLICY_VIOLATION)