from typing import List, Optional
from decimal import Decimal
from datetime import datetime, timezone

from sqlalchemy import select, func, update, delete
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload, joinedload

from app.models import (
    SettlementActivity,
    ExpenseSplit,
    Expense,
    User,
    ExpenseSplitStatusEnum,
    ExpenseOverallStatusEnum,
)
from pydantic import BaseModel
from app.crud.audit import create_financial_audit_log
from app.schemas.settlement_activity import SettlementActivityCreate
from app.core.exceptions import UserNotFoundError, InvalidOperationError, FinancialConflictError


class SettlementActivityCreatePlaceholder(BaseModel):
    """Local input schema mirroring the fields needed to record a settlement activity.

    NOTE(review): not referenced by the functions in this module, which accept
    ``SettlementActivityCreate`` from ``app.schemas.settlement_activity`` instead.
    Confirm nothing else imports this before removing it.
    """

    expense_split_id: int
    paid_by_user_id: int
    amount_paid: Decimal
    paid_at: Optional[datetime] = None

    class Config:
        orm_mode = True


async def update_expense_split_status(db: AsyncSession, expense_split_id: int) -> Optional[ExpenseSplit]:
    """
    Recompute an ExpenseSplit's status from its settlement activities.

    The activities' ``amount_paid`` values are summed and rounded to cents, then:

    * total >= ``owed_amount``  -> ``paid``; ``paid_at`` is set to the latest
      activity ``paid_at`` (activities without one are ignored), or to the
      current UTC time if no activity carries a timestamp;
    * 0 < total < ``owed_amount`` -> ``partially_paid``; ``paid_at`` cleared;
    * otherwise                   -> ``unpaid``; ``paid_at`` cleared.

    The parent Expense's overall status is NOT updated here; callers such as
    ``create_settlement_activity`` call ``update_expense_overall_status``
    afterwards. Changes are flushed, not committed.

    Args:
        db: Active async session; the caller owns the transaction.
        expense_split_id: Primary key of the split to recompute.

    Returns:
        The updated ExpenseSplit, or None if no split with that ID exists.
    """
    result = await db.execute(
        select(ExpenseSplit)
        .options(
            selectinload(ExpenseSplit.settlement_activities),
            joinedload(ExpenseSplit.expense),  # load the parent Expense alongside the split
        )
        .where(ExpenseSplit.id == expense_split_id)
    )
    expense_split = result.scalar_one_or_none()
    if not expense_split:
        return None

    activities = expense_split.settlement_activities
    total_paid = Decimal(sum(activity.amount_paid for activity in activities)).quantize(Decimal("0.01"))

    if total_paid >= expense_split.owed_amount:
        expense_split.status = ExpenseSplitStatusEnum.paid
        # default=None: if every activity lacks paid_at, the filtered generator is
        # empty and a bare max() would raise ValueError.
        latest_paid_at = max(
            (activity.paid_at for activity in activities if activity.paid_at is not None),
            default=None,
        )
        expense_split.paid_at = latest_paid_at if latest_paid_at is not None else datetime.now(timezone.utc)
    elif total_paid > 0:
        expense_split.status = ExpenseSplitStatusEnum.partially_paid
        expense_split.paid_at = None
    else:
        # Nothing (net) has been paid against this split.
        expense_split.status = ExpenseSplitStatusEnum.unpaid
        expense_split.paid_at = None

    await db.flush()
    await db.refresh(expense_split, attribute_names=['status', 'paid_at', 'expense'])
    return expense_split


async def update_expense_overall_status(db: AsyncSession, expense_id: int) -> Optional[Expense]:
    """
    Recompute an Expense's ``overall_settlement_status`` from its splits.

    * no splits                                       -> ``unpaid``
    * every split ``paid``                            -> ``paid``
    * no split ``paid`` or ``partially_paid``         -> ``unpaid``
    * any other mix                                   -> ``partially_paid``

    Changes are flushed, not committed.

    Args:
        db: Active async session; the caller owns the transaction.
        expense_id: Primary key of the expense to recompute.

    Returns:
        The updated Expense, or None if no expense with that ID exists.
    """
    result = await db.execute(
        select(Expense).options(selectinload(Expense.splits)).where(Expense.id == expense_id)
    )
    expense = result.scalar_one_or_none()
    if not expense:
        return None

    split_statuses = [split.status for split in expense.splits]
    # Any split status outside this pair counts as "unpaid" for the roll-up.
    statuses_with_payment = (ExpenseSplitStatusEnum.paid, ExpenseSplitStatusEnum.partially_paid)

    if not split_statuses:
        new_status = ExpenseOverallStatusEnum.unpaid
    elif all(status == ExpenseSplitStatusEnum.paid for status in split_statuses):
        new_status = ExpenseOverallStatusEnum.paid
    elif all(status not in statuses_with_payment for status in split_statuses):
        new_status = ExpenseOverallStatusEnum.unpaid
    else:
        new_status = ExpenseOverallStatusEnum.partially_paid

    expense.overall_settlement_status = new_status
    await db.flush()
    await db.refresh(expense, attribute_names=['overall_settlement_status'])
    return expense


async def create_settlement_activity(
    db: AsyncSession,
    settlement_activity_in: SettlementActivityCreate,
    current_user_id: int
) -> SettlementActivity:
    """
    Record a payment against an expense split and refresh the derived statuses.

    Steps, all inside the caller's transaction:
      1. Load the target ExpenseSplit with ``SELECT ... FOR UPDATE`` so that
         concurrent settlements of the same split serialize (on backends that
         support row locking).
      2. Reject the request if the split is already fully paid or the payer
         does not exist.
      3. Insert the SettlementActivity (``paid_at`` defaults to now, UTC) and
         write a financial audit log entry.
      4. Recompute the split status, then the parent expense's overall status.
      5. Re-select the activity with ``payer`` and ``creator`` eagerly loaded
         so it can be serialized without lazy loading.

    Args:
        db: Active async session; the caller is responsible for commit/rollback.
        settlement_activity_in: Validated input describing the payment.
        current_user_id: ID of the user performing the action (stored as
            ``created_by_user_id`` and used for the audit log).

    Returns:
        The newly created SettlementActivity with ``payer`` and ``creator`` loaded.

    Raises:
        InvalidOperationError: The split does not exist, or the new activity
            could not be re-loaded after creation.
        FinancialConflictError: The split is already fully paid.
        UserNotFoundError: ``paid_by_user_id`` does not match any user.
    """
    # Lock the expense split row for the duration of the transaction.
    split_stmt = (
        select(ExpenseSplit)
        .where(ExpenseSplit.id == settlement_activity_in.expense_split_id)
        .with_for_update()
    )
    split_result = await db.execute(split_stmt)
    expense_split = split_result.scalar_one_or_none()

    if not expense_split:
        raise InvalidOperationError(f"Expense split with ID {settlement_activity_in.expense_split_id} not found.")

    if expense_split.status == ExpenseSplitStatusEnum.paid:
        raise FinancialConflictError(f"Expense split {expense_split.id} is already fully paid.")

    user_result = await db.execute(select(User).where(User.id == settlement_activity_in.paid_by_user_id))
    if not user_result.scalar_one_or_none():
        raise UserNotFoundError(user_id=settlement_activity_in.paid_by_user_id)

    paid_at = settlement_activity_in.paid_at
    db_settlement_activity = SettlementActivity(
        expense_split_id=settlement_activity_in.expense_split_id,
        paid_by_user_id=settlement_activity_in.paid_by_user_id,
        amount_paid=settlement_activity_in.amount_paid,
        paid_at=paid_at if paid_at is not None else datetime.now(timezone.utc),
        created_by_user_id=current_user_id
    )
    db.add(db_settlement_activity)
    # Flush so the activity gets its primary key before auditing and before the
    # status recomputation queries the split's activities.
    await db.flush()

    await create_financial_audit_log(
        db=db,
        user_id=current_user_id,
        action_type="SETTLEMENT_ACTIVITY_CREATED",
        entity=db_settlement_activity,
    )

    # Update statuses within the same transaction.
    updated_split = await update_expense_split_status(db, expense_split_id=db_settlement_activity.expense_split_id)
    if updated_split and updated_split.expense_id:
        await update_expense_overall_status(db, expense_id=updated_split.expense_id)

    # Re-fetch with relationships loaded to prevent lazy-loading during serialization.
    stmt = (
        select(SettlementActivity)
        .where(SettlementActivity.id == db_settlement_activity.id)
        .options(
            selectinload(SettlementActivity.payer),
            selectinload(SettlementActivity.creator)
        )
    )
    result = await db.execute(stmt)
    loaded_activity = result.scalar_one_or_none()

    if not loaded_activity:
        # This should not happen in a normal flow.
        raise InvalidOperationError("Failed to load settlement activity after creation.")

    return loaded_activity


async def get_settlement_activity_by_id(
    db: AsyncSession, settlement_activity_id: int
) -> Optional[SettlementActivity]:
    """
    Fetch a single SettlementActivity by primary key.

    Eagerly loads ``split`` (and the split's ``expense``), ``payer`` and ``creator``.

    Returns:
        The activity, or None if no activity with that ID exists.
    """
    result = await db.execute(
        select(SettlementActivity)
        .options(
            selectinload(SettlementActivity.split).selectinload(ExpenseSplit.expense),
            selectinload(SettlementActivity.payer),
            selectinload(SettlementActivity.creator)
        )
        .where(SettlementActivity.id == settlement_activity_id)
    )
    return result.scalar_one_or_none()


async def get_settlement_activities_for_split(
    db: AsyncSession, expense_split_id: int, skip: int = 0, limit: int = 100
) -> List[SettlementActivity]:
    """
    List the SettlementActivity records for one expense split.

    Results are ordered newest first by ``paid_at``, then ``created_at``, and
    paginated with ``skip``/``limit``. ``payer`` and ``creator`` are eagerly loaded.

    Returns:
        A (possibly empty) list of activities.
    """
    result = await db.execute(
        select(SettlementActivity)
        .where(SettlementActivity.expense_split_id == expense_split_id)
        .options(
            selectinload(SettlementActivity.payer),
            selectinload(SettlementActivity.creator)
        )
        .order_by(SettlementActivity.paid_at.desc(), SettlementActivity.created_at.desc())
        .offset(skip)
        .limit(limit)
    )
    # ScalarResult.all() returns a Sequence in SQLAlchemy 2.x; materialize a real
    # list to honour the declared return type.
    return list(result.scalars().all())

# Further CRUD operations like update/delete can be added later if needed.