from datetime import datetime, timezone
from decimal import Decimal
from typing import List, Optional

from pydantic import BaseModel
from sqlalchemy import select, func, update, delete
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload, joinedload

from app.models import (
    SettlementActivity,
    ExpenseSplit,
    Expense,
    User,
    ExpenseSplitStatusEnum,
    ExpenseOverallStatusEnum,
)


class SettlementActivityCreatePlaceholder(BaseModel):
    """Input payload for creating a settlement activity.

    ``paid_at`` is optional; :func:`create_settlement_activity` substitutes
    the current UTC time when it is not provided.
    """

    expense_split_id: int
    paid_by_user_id: int
    amount_paid: Decimal
    paid_at: Optional[datetime] = None

    class Config:
        orm_mode = True


async def update_expense_split_status(db: AsyncSession, expense_split_id: int) -> Optional[ExpenseSplit]:
    """Recompute an ExpenseSplit's status from its settlement activities.

    The activities' ``amount_paid`` values are summed and rounded to two
    decimal places, then:

    * ``>= owed_amount`` -> ``paid``; ``paid_at`` becomes the latest activity
      ``paid_at``, or the current UTC time if no activity has one.
    * ``> 0``            -> ``partially_paid``; ``paid_at`` is cleared.
    * otherwise          -> ``unpaid``; ``paid_at`` is cleared.

    This does NOT update the parent Expense; callers should follow up with
    ``update_expense_overall_status``.

    Args:
        db: Async session; changes are flushed, not committed.
        expense_split_id: Primary key of the split to recompute.

    Returns:
        The updated split, or ``None`` if no split has that id.
    """
    result = await db.execute(
        select(ExpenseSplit)
        .options(
            selectinload(ExpenseSplit.settlement_activities),
            joinedload(ExpenseSplit.expense),  # To get expense_id easily
        )
        .where(ExpenseSplit.id == expense_split_id)
        # Activities are created by foreign key only, so an already-loaded
        # settlement_activities collection in this session would not include
        # them, and selectinload skips collections that are already loaded.
        # populate_existing forces a reload from the database.
        .execution_options(populate_existing=True)
    )
    expense_split = result.scalar_one_or_none()
    if not expense_split:
        return None

    activities = expense_split.settlement_activities
    total_paid = sum(activity.amount_paid for activity in activities)
    total_paid = Decimal(total_paid).quantize(Decimal("0.01"))

    if total_paid >= expense_split.owed_amount:
        expense_split.status = ExpenseSplitStatusEnum.paid
        # default=None: activities may lack paid_at, and max() raises
        # ValueError on an empty sequence.
        latest_paid_at = max(
            (act.paid_at for act in activities if act.paid_at),
            default=None,
        )
        expense_split.paid_at = latest_paid_at if latest_paid_at else datetime.now(timezone.utc)
    elif total_paid > 0:
        expense_split.status = ExpenseSplitStatusEnum.partially_paid
        expense_split.paid_at = None
    else:  # total_paid <= 0
        expense_split.status = ExpenseSplitStatusEnum.unpaid
        expense_split.paid_at = None

    await db.flush()
    await db.refresh(expense_split, attribute_names=["status", "paid_at", "expense"])
    return expense_split


async def update_expense_overall_status(db: AsyncSession, expense_id: int) -> Optional[Expense]:
    """Recompute an Expense's ``overall_settlement_status`` from its splits.

    * no splits, or every split unpaid (neither ``paid`` nor
      ``partially_paid``) -> ``unpaid``
    * every split ``paid``                                  -> ``paid``
    * anything else                                          -> ``partially_paid``

    Args:
        db: Async session; changes are flushed, not committed.
        expense_id: Primary key of the expense to recompute.

    Returns:
        The updated expense, or ``None`` if no expense has that id.
    """
    result = await db.execute(
        select(Expense).options(selectinload(Expense.splits)).where(Expense.id == expense_id)
    )
    expense = result.scalar_one_or_none()
    if not expense:
        return None

    if not expense.splits:
        # Guard needed: all() over an empty sequence would report "paid".
        expense.overall_settlement_status = ExpenseOverallStatusEnum.unpaid
        await db.flush()
        await db.refresh(expense)
        return expense

    statuses = [split.status for split in expense.splits]
    settled_or_started = (ExpenseSplitStatusEnum.paid, ExpenseSplitStatusEnum.partially_paid)

    if all(status == ExpenseSplitStatusEnum.paid for status in statuses):
        expense.overall_settlement_status = ExpenseOverallStatusEnum.paid
    elif all(status not in settled_or_started for status in statuses):
        expense.overall_settlement_status = ExpenseOverallStatusEnum.unpaid
    else:
        expense.overall_settlement_status = ExpenseOverallStatusEnum.partially_paid

    await db.flush()
    await db.refresh(expense, attribute_names=["overall_settlement_status"])
    return expense


async def create_settlement_activity(
    db: AsyncSession,
    settlement_activity_in: SettlementActivityCreatePlaceholder,
    current_user_id: int,
) -> Optional[SettlementActivity]:
    """Record a settlement payment and cascade the status updates.

    Creates a SettlementActivity. It then recomputes the parent
    ExpenseSplit's status and, when the split has an ``expense_id``, the
    Expense's overall status. Changes are flushed, not committed.

    Args:
        db: Async session.
        settlement_activity_in: Payment details; ``paid_at`` defaults to the
            current UTC time when omitted.
        current_user_id: Stored as the activity's ``created_by_user_id``.

    Returns:
        The new activity, or ``None`` (nothing added) when the referenced
        split or paying user does not exist.
    """
    split_result = await db.execute(
        select(ExpenseSplit).where(ExpenseSplit.id == settlement_activity_in.expense_split_id)
    )
    expense_split = split_result.scalar_one_or_none()
    if not expense_split:
        return None

    user_result = await db.execute(select(User).where(User.id == settlement_activity_in.paid_by_user_id))
    paid_by_user = user_result.scalar_one_or_none()
    if not paid_by_user:
        return None  # User not found

    db_settlement_activity = SettlementActivity(
        expense_split_id=settlement_activity_in.expense_split_id,
        paid_by_user_id=settlement_activity_in.paid_by_user_id,
        amount_paid=settlement_activity_in.amount_paid,
        paid_at=settlement_activity_in.paid_at if settlement_activity_in.paid_at else datetime.now(timezone.utc),
        created_by_user_id=current_user_id,
    )
    db.add(db_settlement_activity)
    # Flush so the new row is visible to the status recomputation queries.
    await db.flush()

    updated_split = await update_expense_split_status(db, expense_split_id=db_settlement_activity.expense_split_id)
    if updated_split and updated_split.expense_id:
        await update_expense_overall_status(db, expense_id=updated_split.expense_id)

    return db_settlement_activity


async def get_settlement_activity_by_id(
    db: AsyncSession, settlement_activity_id: int
) -> Optional[SettlementActivity]:
    """Fetch one SettlementActivity by id, or ``None`` if it does not exist.

    Eagerly loads the activity's split (and that split's expense), its
    payer, and its creator.
    """
    result = await db.execute(
        select(SettlementActivity)
        .options(
            selectinload(SettlementActivity.split).selectinload(ExpenseSplit.expense),
            selectinload(SettlementActivity.payer),
            selectinload(SettlementActivity.creator),
        )
        .where(SettlementActivity.id == settlement_activity_id)
    )
    return result.scalar_one_or_none()


async def get_settlement_activities_for_split(
    db: AsyncSession, expense_split_id: int, skip: int = 0, limit: int = 100
) -> List[SettlementActivity]:
    """List the settlement activities recorded against one expense split.

    Results are ordered newest first by ``paid_at``, then ``created_at``,
    paginated by ``skip``/``limit``. The payer and creator are eagerly loaded.

    Args:
        db: Async session.
        expense_split_id: Split whose activities to return.
        skip: Number of rows to skip (offset).
        limit: Maximum number of rows to return.

    Returns:
        A list of activities (empty if none match).
    """
    result = await db.execute(
        select(SettlementActivity)
        .where(SettlementActivity.expense_split_id == expense_split_id)
        .options(
            selectinload(SettlementActivity.payer),
            selectinload(SettlementActivity.creator),
        )
        .order_by(SettlementActivity.paid_at.desc(), SettlementActivity.created_at.desc())
        .offset(skip)
        .limit(limit)
    )
    # scalars().all() returns a Sequence in SQLAlchemy 2.0; list() honours
    # the List return annotation.
    return list(result.scalars().all())


# Further CRUD operations like update/delete can be added later if needed.