# app/crud/invite.py
"""CRUD helpers for group invite codes."""
import secrets
from datetime import datetime, timedelta, timezone
from typing import Optional

from sqlalchemy import delete  # noqa: F401 -- Import delete statement (not used in this module yet)
from sqlalchemy.exc import IntegrityError, OperationalError, SQLAlchemyError
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy.orm import selectinload

from app.core.exceptions import (
    DatabaseConnectionError,
    DatabaseIntegrityError,
    DatabaseQueryError,
    DatabaseTransactionError,
    InviteOperationError,
)
# Group/User are imported alongside Invite for the related models used by selectinload.
from app.models import Invite as InviteModel, Group as GroupModel, User as UserModel  # noqa: F401

# Invite codes should be reasonably unique, but retry a few times on collision.
MAX_CODE_GENERATION_ATTEMPTS = 5

# Random bytes fed to secrets.token_urlsafe (base64 => ~4 chars per 3 bytes).
_INVITE_CODE_NBYTES = 16


def _transaction_scope(db: AsyncSession):
    """Return a transaction context manager for ``db``.

    Opens a SAVEPOINT (``begin_nested``) when the session is already in a
    transaction, otherwise starts a top-level transaction (``begin``). Use as
    ``async with _transaction_scope(db):`` -- the context manager commits on
    normal exit and rolls back if the body raises.
    """
    return db.begin_nested() if db.in_transaction() else db.begin()


def _select_invite_with_relations():
    """Build a ``SELECT Invite`` that eagerly loads ``group`` and ``creator``."""
    return select(InviteModel).options(
        selectinload(InviteModel.group),
        selectinload(InviteModel.creator),
    )


async def _active_code_exists(db: AsyncSession, code: str) -> bool:
    """Return True if an active invite already uses ``code``."""
    stmt = (
        select(InviteModel.id)
        # ``== True`` builds a SQL expression; ``is True`` would be a Python identity test.
        .where(InviteModel.code == code, InviteModel.is_active == True)  # noqa: E712
        .limit(1)
    )
    result = await db.execute(stmt)
    return result.scalar_one_or_none() is not None


async def _generate_unique_code(db: AsyncSession) -> str:
    """Return a random code that no active invite currently uses.

    Raises:
        InviteOperationError: every attempt produced a code already in use.
    """
    for _ in range(MAX_CODE_GENERATION_ATTEMPTS):
        candidate = secrets.token_urlsafe(_INVITE_CODE_NBYTES)
        if not await _active_code_exists(db, candidate):
            return candidate
    raise InviteOperationError("Failed to generate a unique invite code after several attempts.")


async def _reload_invite(db: AsyncSession, invite_id: int) -> Optional[InviteModel]:
    """Fetch an invite by id with ``group`` and ``creator`` eagerly loaded, or None."""
    result = await db.execute(_select_invite_with_relations().where(InviteModel.id == invite_id))
    return result.scalar_one_or_none()


async def create_invite(
    db: AsyncSession, group_id: int, creator_id: int, expires_in_days: int = 7
) -> Optional[InviteModel]:
    """Create a new active invite code for a group.

    Generates a random code (retrying if an active invite already uses it),
    inserts the invite inside a transaction scope, and re-reads it with
    ``group`` and ``creator`` eagerly loaded.

    Args:
        db: Async session to run the queries on.
        group_id: Stored as the invite's ``group_id``.
        creator_id: Stored as the invite's ``created_by_id``.
        expires_in_days: Invite lifetime, counted from the current UTC time.

    Returns:
        The newly created invite with relationships loaded.

    Raises:
        InviteOperationError: no unique code could be generated, a collision
            was detected inside the transaction, or the new row could not be
            re-read.
        DatabaseIntegrityError: the insert violated a DB constraint.
        DatabaseConnectionError: a DB connection/operational error occurred.
        DatabaseTransactionError: any other SQLAlchemy error occurred.

    NOTE(review): the code-availability check queries ``db`` before the write,
    which likely autobegins a transaction; the write then takes the
    ``begin_nested()`` (SAVEPOINT) path and leaving the scope only releases
    the savepoint. Confirm the caller commits the outer transaction.
    """
    expires_at = datetime.now(timezone.utc) + timedelta(days=expires_in_days)
    try:
        # The pre-check only narrows the collision window; without a DB unique
        # constraint on active codes, a concurrent insert could still race us.
        code = await _generate_unique_code(db)

        async with _transaction_scope(db):
            # Re-check inside the transaction before inserting.
            if await _active_code_exists(db, code):
                raise InviteOperationError("Invite code collision detected during transaction.")

            db_invite = InviteModel(
                code=code,
                group_id=group_id,
                created_by_id=creator_id,
                expires_at=expires_at,
                is_active=True,
            )
            db.add(db_invite)
            await db.flush()  # Emit the INSERT so db_invite.id is assigned.

            loaded_invite = await _reload_invite(db, db_invite.id)
            if loaded_invite is None:
                raise InviteOperationError("Failed to load invite after creation.")
        # Exiting the ``async with`` committed the scope (errors roll it back).
        return loaded_invite
    except IntegrityError as e:
        # e.g. a DB unique constraint on the code was violated.
        raise DatabaseIntegrityError(f"Failed to create invite due to DB integrity: {e}") from e
    except OperationalError as e:
        raise DatabaseConnectionError(f"DB connection error during invite creation: {e}") from e
    except SQLAlchemyError as e:
        raise DatabaseTransactionError(f"DB transaction error during invite creation: {e}") from e


async def get_active_invite_by_code(db: AsyncSession, code: str) -> Optional[InviteModel]:
    """Return the active, non-expired invite with ``code``, or None if there is none.

    ``group`` and ``creator`` are eagerly loaded on the returned invite.

    Raises:
        DatabaseConnectionError: a DB connection/operational error occurred.
        DatabaseQueryError: any other SQLAlchemy error occurred.
    """
    now = datetime.now(timezone.utc)
    try:
        stmt = _select_invite_with_relations().where(
            InviteModel.code == code,
            InviteModel.is_active == True,  # noqa: E712 -- SQL expression, not a bool test
            InviteModel.expires_at > now,
        )
        result = await db.execute(stmt)
        return result.scalars().first()
    except OperationalError as e:
        raise DatabaseConnectionError(f"DB connection error fetching invite: {e}") from e
    except SQLAlchemyError as e:
        raise DatabaseQueryError(f"DB query error fetching invite: {e}") from e


async def deactivate_invite(db: AsyncSession, invite: InviteModel) -> InviteModel:
    """Mark ``invite`` inactive (used) and return it re-read with relationships.

    Args:
        db: Async session to run the update on.
        invite: The invite to deactivate; its ``is_active`` is set to False.

    Returns:
        The invite re-read with ``group`` and ``creator`` eagerly loaded.

    Raises:
        InviteOperationError: the invite could not be re-read after the update.
        DatabaseConnectionError: a DB connection/operational error occurred.
        DatabaseTransactionError: any other SQLAlchemy error occurred.
    """
    try:
        async with _transaction_scope(db):
            invite.is_active = False
            db.add(invite)  # Make sure this session tracks the change.
            await db.flush()  # Persist the is_active change before re-reading.

            updated_invite = await _reload_invite(db, invite.id)
            if updated_invite is None:
                # Not expected: the row was just flushed in this transaction.
                raise InviteOperationError("Failed to load invite after deactivation.")
        # Exiting the ``async with`` committed the scope (errors roll it back).
        return updated_invite
    except OperationalError as e:
        raise DatabaseConnectionError(f"DB connection error deactivating invite: {e}") from e
    except SQLAlchemyError as e:
        raise DatabaseTransactionError(f"DB transaction error deactivating invite: {e}") from e


# NOTE: InviteOperationError must be defined in app.core.exceptions, e.g.
#     class InviteOperationError(AppException): pass

# TODO: optionally add a function that periodically deletes old, inactive invites:
#     async def cleanup_old_invites(db: AsyncSession, older_than_days: int = 30): ...