mitlist/be/app/crud/invite.py

188 lines
8.3 KiB
Python

import logging
import secrets
from datetime import datetime, timedelta, timezone
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy.orm import selectinload
from sqlalchemy import delete
from sqlalchemy.exc import SQLAlchemyError, OperationalError, IntegrityError
from typing import Optional
from app.models import Invite as InviteModel, Group as GroupModel, User as UserModel
from app.core.exceptions import (
DatabaseConnectionError,
DatabaseIntegrityError,
DatabaseQueryError,
DatabaseTransactionError,
InviteOperationError
)
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Upper bound on how many random codes create_invite() tries before giving up.
MAX_CODE_GENERATION_ATTEMPTS = 5
async def deactivate_all_active_invites_for_group(db: AsyncSession, group_id: int):
    """Deactivate every currently active invite code of one group.

    The work runs inside ``db.begin_nested()`` when the session is already in
    a transaction, otherwise inside ``db.begin()``. Matching invites get
    ``is_active = False`` and the change is flushed; nothing happens when the
    group has no active invites.

    Args:
        db: Async session used for the lookup and the update.
        group_id: Id of the group whose invites are deactivated.

    Raises:
        DatabaseConnectionError: If SQLAlchemy raises ``OperationalError``.
        DatabaseTransactionError: On any other ``SQLAlchemyError``.
    """
    try:
        async with db.begin_nested() if db.in_transaction() else db.begin():
            stmt = select(InviteModel).where(
                InviteModel.group_id == group_id,
                # `== True` builds a SQL expression here; it is not a Python truth test.
                InviteModel.is_active == True,  # noqa: E712
            )
            result = await db.execute(stmt)
            active_invites = result.scalars().all()
            if not active_invites:
                return
            for invite in active_invites:
                # Instances loaded through this session are already tracked by
                # it, so mutating them is enough; no explicit db.add() needed.
                invite.is_active = False
            await db.flush()
    except OperationalError as e:
        logger.error(f"Database connection error deactivating invites for group {group_id}: {e}", exc_info=True)
        raise DatabaseConnectionError(
            f"DB connection error deactivating invites for group {group_id}: {e}"
        ) from e
    except SQLAlchemyError as e:
        logger.error(f"Unexpected SQLAlchemy error deactivating invites for group {group_id}: {e}", exc_info=True)
        raise DatabaseTransactionError(
            f"DB transaction error deactivating invites for group {group_id}: {e}"
        ) from e
async def create_invite(
    db: AsyncSession,
    group_id: int,
    creator_id: int,
    expires_in_days: int = 365 * 100,
) -> Optional[InviteModel]:
    """Create a new invite code for a group, deactivating existing active ones first.

    Everything runs inside ``db.begin_nested()`` when the session is already
    in a transaction, otherwise inside ``db.begin()``, so a failure while
    generating or inserting the code also undoes the deactivation step.

    Args:
        db: Async session used for all queries.
        group_id: Id of the group the invite belongs to.
        creator_id: Id of the user creating the invite (stored as
            ``created_by_id``).
        expires_in_days: Lifetime of the invite in days, counted from the
            current UTC time. Defaults to 100 years.

    Returns:
        The newly created invite, re-read with its ``group`` and ``creator``
        relationships eagerly loaded. No visible code path returns ``None``;
        the ``Optional`` annotation is kept for interface compatibility.

    Raises:
        InviteOperationError: If no unused code is found within
            ``MAX_CODE_GENERATION_ATTEMPTS`` tries, or the new invite cannot
            be re-read after the flush.
        DatabaseIntegrityError: If SQLAlchemy raises ``IntegrityError``.
        DatabaseConnectionError: If SQLAlchemy raises ``OperationalError``.
        DatabaseTransactionError: On any other ``SQLAlchemyError``.
    """
    try:
        async with db.begin_nested() if db.in_transaction() else db.begin():
            await deactivate_all_active_invites_for_group(db, group_id)
            expires_at = datetime.now(timezone.utc) + timedelta(days=expires_in_days)

            for _ in range(MAX_CODE_GENERATION_ATTEMPTS):
                potential_code = secrets.token_urlsafe(16)
                collision_stmt = (
                    select(InviteModel.id)
                    .where(
                        InviteModel.code == potential_code,
                        InviteModel.is_active == True,  # noqa: E712 - SQL expression
                    )
                    .limit(1)
                )
                collision = await db.execute(collision_stmt)
                if collision.scalar_one_or_none() is None:
                    break
            else:
                # Every attempt collided with an active code (or no attempt was allowed).
                raise InviteOperationError("Failed to generate a unique invite code after several attempts.")

            # A second pre-insert lookup would be exposed to the same race as
            # the one above, so none is made. NOTE(review): a concurrent
            # duplicate should surface as IntegrityError at flush if `code`
            # has a unique constraint - confirm against the model.
            db_invite = InviteModel(
                code=potential_code,
                group_id=group_id,
                created_by_id=creator_id,
                expires_at=expires_at,
                is_active=True,
            )
            db.add(db_invite)
            await db.flush()  # sends the INSERT so db_invite.id is populated

            # Re-read with relationships eagerly loaded so callers can access
            # them without triggering lazy loads.
            reload_stmt = (
                select(InviteModel)
                .where(InviteModel.id == db_invite.id)
                .options(
                    selectinload(InviteModel.group),
                    selectinload(InviteModel.creator),
                )
            )
            result = await db.execute(reload_stmt)
            loaded_invite = result.scalar_one_or_none()
            if loaded_invite is None:
                raise InviteOperationError("Failed to load invite after creation and flush.")
            return loaded_invite
    except InviteOperationError:
        # Domain errors pass through untouched.
        raise
    except IntegrityError as e:
        logger.error(f"Database integrity error during invite creation for group {group_id}: {e}", exc_info=True)
        raise DatabaseIntegrityError(f"Failed to create invite due to DB integrity issue: {e}") from e
    except OperationalError as e:
        logger.error(f"Database connection error during invite creation for group {group_id}: {e}", exc_info=True)
        raise DatabaseConnectionError(f"DB connection error during invite creation: {e}") from e
    except SQLAlchemyError as e:
        logger.error(f"Unexpected SQLAlchemy error during invite creation for group {group_id}: {e}", exc_info=True)
        raise DatabaseTransactionError(f"DB transaction error during invite creation: {e}") from e
async def get_active_invite_for_group(db: AsyncSession, group_id: int) -> Optional[InviteModel]:
    """Return an active, non-expired invite of a group, if one exists.

    The invite is loaded together with its ``group`` and ``creator``
    relationships. The query uses ``LIMIT 1`` without an ``ORDER BY``, so if
    several active invites exist for the group the database decides which
    one is returned.

    Args:
        db: Async session used for the query.
        group_id: Id of the group to look up.

    Returns:
        The matching invite, or ``None`` when the group has no active invite
        whose ``expires_at`` is later than the current UTC time.

    Raises:
        DatabaseConnectionError: If SQLAlchemy raises ``OperationalError``.
        DatabaseQueryError: On any other ``SQLAlchemyError``.
    """
    now = datetime.now(timezone.utc)
    try:
        stmt = (
            select(InviteModel)
            .where(
                InviteModel.group_id == group_id,
                InviteModel.is_active == True,  # noqa: E712 - SQL expression
                InviteModel.expires_at > now,
            )
            .limit(1)
            .options(
                selectinload(InviteModel.group),
                selectinload(InviteModel.creator),
            )
        )
        result = await db.execute(stmt)
        return result.scalars().first()
    except OperationalError as e:
        logger.error(f"Database connection error fetching active invite for group {group_id}: {e}", exc_info=True)
        raise DatabaseConnectionError(
            f"DB connection error fetching active invite for group {group_id}: {e}"
        ) from e
    except SQLAlchemyError as e:
        logger.error(f"DB query error fetching active invite for group {group_id}: {e}", exc_info=True)
        raise DatabaseQueryError(
            f"DB query error fetching active invite for group {group_id}: {e}"
        ) from e
async def get_active_invite_by_code(db: AsyncSession, code: str) -> Optional[InviteModel]:
    """Return the active, non-expired invite matching a code, if one exists.

    The invite is loaded together with its ``group`` and ``creator``
    relationships.

    Args:
        db: Async session used for the query.
        code: Invite code to look up.

    Returns:
        The matching invite, or ``None`` when no active invite with this code
        has an ``expires_at`` later than the current UTC time.

    Raises:
        DatabaseConnectionError: If SQLAlchemy raises ``OperationalError``.
        DatabaseQueryError: On any other ``SQLAlchemyError``.
    """
    now = datetime.now(timezone.utc)
    try:
        stmt = (
            select(InviteModel)
            .where(
                InviteModel.code == code,
                InviteModel.is_active == True,  # noqa: E712 - SQL expression
                InviteModel.expires_at > now,
            )
            .options(
                selectinload(InviteModel.group),
                selectinload(InviteModel.creator),
            )
        )
        result = await db.execute(stmt)
        return result.scalars().first()
    except OperationalError as e:
        # Log before re-raising, as the other functions in this module do.
        # The invite code is deliberately not interpolated into the message.
        logger.error(f"Database connection error fetching invite by code: {e}", exc_info=True)
        raise DatabaseConnectionError(f"DB connection error fetching invite: {e}") from e
    except SQLAlchemyError as e:
        logger.error(f"DB query error fetching invite by code: {e}", exc_info=True)
        raise DatabaseQueryError(f"DB query error fetching invite: {e}") from e
async def deactivate_invite(db: AsyncSession, invite: InviteModel) -> InviteModel:
    """Mark an invite as inactive and return it reloaded with relationships.

    The work runs inside ``db.begin_nested()`` when the session is already in
    a transaction, otherwise inside ``db.begin()``.

    Args:
        db: Async session used for the update and the reload.
        invite: Invite to deactivate; it is added to ``db`` before flushing.

    Returns:
        The invite re-read by id with its ``group`` and ``creator``
        relationships eagerly loaded.

    Raises:
        InviteOperationError: If the invite cannot be re-read after the flush.
        DatabaseConnectionError: If SQLAlchemy raises ``OperationalError``.
        DatabaseTransactionError: On any other ``SQLAlchemyError``.
    """
    try:
        async with db.begin_nested() if db.in_transaction() else db.begin():
            invite.is_active = False
            # The caller supplies the instance, so make sure this session tracks it.
            db.add(invite)
            await db.flush()

            reload_stmt = (
                select(InviteModel)
                .where(InviteModel.id == invite.id)
                .options(
                    selectinload(InviteModel.group),
                    selectinload(InviteModel.creator),
                )
            )
            result = await db.execute(reload_stmt)
            updated_invite = result.scalar_one_or_none()
            if updated_invite is None:
                raise InviteOperationError("Failed to load invite after deactivation.")
            return updated_invite
    except OperationalError as e:
        logger.error(f"Database connection error deactivating invite: {e}", exc_info=True)
        raise DatabaseConnectionError(f"DB connection error deactivating invite: {e}") from e
    except SQLAlchemyError as e:
        logger.error(f"Unexpected SQLAlchemy error deactivating invite: {e}", exc_info=True)
        raise DatabaseTransactionError(f"DB transaction error deactivating invite: {e}") from e