# be/app/services/costs_service.py
import logging
from decimal import Decimal, ROUND_HALF_UP, ROUND_DOWN
from typing import List

from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload

from app.models import (
    User as UserModel,
    Group as GroupModel,
    List as ListModel,
    Expense as ExpenseModel,
    Item as ItemModel,
    UserGroup as UserGroupModel,
    SplitTypeEnum,
    ExpenseSplit as ExpenseSplitModel,
    SettlementActivity as SettlementActivityModel,
    Settlement as SettlementModel,
)
from app.schemas.cost import (
    ListCostSummary,
    GroupBalanceSummary,
    UserCostShare,
    UserBalanceDetail,
    SuggestedSettlement,
)
from app.schemas.expense import ExpenseCreate, ExpensePublic
from app.crud import list as crud_list
from app.crud import expense as crud_expense
from app.core.exceptions import (
    ListNotFoundError,
    ListPermissionError,
    GroupNotFoundError,
    GroupPermissionError,
    InvalidOperationError,
)

logger = logging.getLogger(__name__)

# Smallest monetary unit used for every rounding step in this module.
_CENT = Decimal("0.01")


def _round_money(value: Decimal) -> Decimal:
    """Quantize *value* to two decimal places using ROUND_HALF_UP."""
    return value.quantize(_CENT, rounding=ROUND_HALF_UP)


def calculate_suggested_settlements(user_balances: List[UserBalanceDetail]) -> List[SuggestedSettlement]:
    """
    Suggest payments that bring the given net balances back towards zero.

    Users whose absolute net balance is below 0.01 are ignored. The remaining
    users are split into debtors (negative balance) and creditors (positive
    balance), each sorted by amount descending, and then matched greedily:
    the current largest debtor pays the current largest creditor the smaller
    of the two outstanding amounts (rounded to two decimals). This is a
    greedy heuristic; it is not guaranteed to produce the minimum possible
    number of transactions.

    Args:
        user_balances: List of UserBalanceDetail objects with their current balances

    Returns:
        List of SuggestedSettlement objects representing the suggested payments
    """
    epsilon = Decimal('0.01')
    debtors = []
    creditors = []

    for user in user_balances:
        if abs(user.net_balance) < epsilon:
            continue

        if user.net_balance < Decimal('0'):
            debtors.append({
                'user_id': user.user_id,
                'user_identifier': user.user_identifier,
                'amount': -user.net_balance,
            })
        else:
            creditors.append({
                'user_id': user.user_id,
                'user_identifier': user.user_identifier,
                'amount': user.net_balance,
            })

    debtors.sort(key=lambda x: x['amount'], reverse=True)
    creditors.sort(key=lambda x: x['amount'], reverse=True)

    settlements = []
    # Index cursors replace repeated list.pop(0), which is O(n) per call.
    debtor_idx = 0
    creditor_idx = 0

    while debtor_idx < len(debtors) and creditor_idx < len(creditors):
        debtor = debtors[debtor_idx]
        creditor = creditors[creditor_idx]

        amount = min(debtor['amount'], creditor['amount']).quantize(Decimal('0.01'), rounding=ROUND_HALF_UP)

        if amount > Decimal('0'):
            settlements.append(
                SuggestedSettlement(
                    from_user_id=debtor['user_id'],
                    from_user_identifier=debtor['user_identifier'],
                    to_user_id=creditor['user_id'],
                    to_user_identifier=creditor['user_identifier'],
                    amount=amount,
                )
            )
            debtor['amount'] -= amount
            creditor['amount'] -= amount

        # Every entry still in play holds at least epsilon, so the side that
        # supplied the minimum always drops below epsilon here and the loop
        # makes progress on each iteration.
        if debtor['amount'] < epsilon:
            debtor_idx += 1
        if creditor['amount'] < epsilon:
            creditor_idx += 1

    return settlements


async def get_list_cost_summary_logic(
    db: AsyncSession, list_id: int, current_user_id: int
) -> ListCostSummary:
    """
    Core logic to retrieve a calculated cost summary for a specific list.
    This version does NOT create an expense if one is not found.

    The first expense found for the list is split equally among all
    participants (users who added a positively priced item, plus users with a
    split on the expense). Shares are rounded down to the cent and any
    leftover cents are handed out one by one in ascending user-id order.

    Args:
        db: Async database session used for all queries.
        list_id: Id of the list to summarize.
        current_user_id: Id of the requesting user; passed to
            ``crud_list.check_list_permission``.

    Returns:
        A ListCostSummary. When no expense exists, or there are no
        participants, the summary carries only the sum of item prices and an
        empty ``user_balances`` list.

    Raises:
        ListNotFoundError: If no list with ``list_id`` exists.
        Whatever ``crud_list.check_list_permission`` raises on denied access.
    """
    await crud_list.check_list_permission(db=db, list_id=list_id, user_id=current_user_id)

    list_result = await db.execute(
        select(ListModel)
        .options(
            selectinload(ListModel.items).options(selectinload(ItemModel.added_by_user)),
            selectinload(ListModel.group).options(
                selectinload(GroupModel.member_associations).options(selectinload(UserGroupModel.user))
            ),
            selectinload(ListModel.creator),
        )
        .where(ListModel.id == list_id)
    )
    db_list = list_result.scalars().first()
    if not db_list:
        raise ListNotFoundError(list_id)

    expense_result = await db.execute(
        select(ExpenseModel)
        .where(ExpenseModel.list_id == list_id)
        .options(selectinload(ExpenseModel.splits).options(selectinload(ExpenseSplitModel.user)))
    )
    db_expense = expense_result.scalars().first()

    # The Decimal start value matters: a bare sum() over an empty sequence
    # returns int 0, which has no .quantize() and would raise AttributeError.
    total_list_cost = sum(
        (item.price for item in db_list.items if item.price is not None and item.price > Decimal("0")),
        Decimal("0"),
    )

    def _items_only_summary() -> ListCostSummary:
        """Summary based on item prices alone, with no per-user breakdown."""
        return ListCostSummary(
            list_id=db_list.id,
            list_name=db_list.name,
            total_list_cost=_round_money(total_list_cost),
            num_participating_users=0,
            equal_share_per_user=Decimal("0.00"),
            user_balances=[],
        )

    # If no expense exists, return a summary based on item prices alone.
    if not db_expense:
        return _items_only_summary()

    # --- Calculation logic based on existing expense ---
    participating_users = set()
    user_items_added_value = {}

    for item in db_list.items:
        if item.price is not None and item.price > Decimal("0") and item.added_by_user:
            participating_users.add(item.added_by_user)
            adder_id = item.added_by_user.id
            user_items_added_value[adder_id] = user_items_added_value.get(adder_id, Decimal("0.00")) + item.price

    for split in db_expense.splits:
        if split.user:
            participating_users.add(split.user)

    num_participating_users = len(participating_users)
    if num_participating_users == 0:
        return _items_only_summary()

    base_share_unrounded = db_expense.total_amount / Decimal(num_participating_users)
    equal_share_per_user_for_response = _round_money(base_share_unrounded)

    # Deterministic order so leftover cents always go to the same users.
    sorted_participating_users = sorted(participating_users, key=lambda u: u.id)

    # Round every share down first, then distribute the remaining cents so
    # that the shares add up exactly to the expense total.
    rounded_down_share = base_share_unrounded.quantize(Decimal("0.01"), rounding=ROUND_DOWN)
    user_final_shares = {user.id: rounded_down_share for user in sorted_participating_users}

    sum_of_rounded_shares = sum(user_final_shares.values(), Decimal("0"))
    remaining_pennies = int(
        ((db_expense.total_amount - sum_of_rounded_shares) * Decimal("100")).to_integral_value(rounding=ROUND_HALF_UP)
    )

    for i in range(remaining_pennies):
        user_to_adjust = sorted_participating_users[i % num_participating_users]
        user_final_shares[user_to_adjust.id] += Decimal("0.01")

    user_balances = []
    for user in sorted_participating_users:
        items_added = user_items_added_value.get(user.id, Decimal("0.00"))
        current_user_share = user_final_shares.get(user.id, Decimal("0.00"))
        # Positive balance: the user added more value than their share.
        balance = items_added - current_user_share
        user_identifier = user.name or user.email
        user_balances.append(
            UserCostShare(
                user_id=user.id,
                user_identifier=user_identifier,
                items_added_value=_round_money(items_added),
                amount_due=_round_money(current_user_share),
                balance=_round_money(balance),
            )
        )

    user_balances.sort(key=lambda x: x.user_identifier)

    return ListCostSummary(
        list_id=db_list.id,
        list_name=db_list.name,
        total_list_cost=_round_money(db_expense.total_amount),
        num_participating_users=num_participating_users,
        equal_share_per_user=equal_share_per_user_for_response,
        user_balances=user_balances,
    )


async def generate_expense_from_list_logic(db: AsyncSession, list_id: int, current_user_id: int) -> ExpenseModel:
    """
    Generates and saves an ITEM_BASED expense from a list's items.

    The expense total is the sum of all positive item prices on the list and
    the list creator is recorded as the payer.

    Args:
        db: Async database session used for all queries.
        list_id: Id of the list to build the expense from.
        current_user_id: Id of the requesting user; used for the permission
            check and passed on to ``crud_expense.create_expense``.

    Returns:
        The value returned by ``crud_expense.create_expense``.

    Raises:
        InvalidOperationError: If an expense already exists for the list, or
            the list has no positively priced items.
        ListNotFoundError: If no list with ``list_id`` exists.
        Whatever ``crud_list.check_list_permission`` raises on denied access.
    """
    await crud_list.check_list_permission(db=db, list_id=list_id, user_id=current_user_id)

    # Check if an expense already exists for this list
    existing_expense_result = await db.execute(
        select(ExpenseModel).where(ExpenseModel.list_id == list_id)
    )
    if existing_expense_result.scalars().first():
        raise InvalidOperationError(f"An expense already exists for list {list_id}.")

    db_list = await db.get(
        ListModel,
        list_id,
        options=[selectinload(ListModel.items), selectinload(ListModel.creator)],
    )
    if not db_list:
        raise ListNotFoundError(list_id)

    # Decimal start value keeps the total a Decimal even for an empty list.
    total_amount = sum(
        (item.price for item in db_list.items if item.price is not None and item.price > Decimal("0")),
        Decimal("0"),
    )
    if total_amount <= Decimal("0"):
        raise InvalidOperationError("Cannot create an expense for a list with no priced items.")

    expense_in = ExpenseCreate(
        description=f"Cost summary for list {db_list.name}",
        total_amount=total_amount,
        list_id=list_id,
        split_type=SplitTypeEnum.ITEM_BASED,
        paid_by_user_id=db_list.creator.id,
    )
    return await crud_expense.create_expense(db=db, expense_in=expense_in, current_user_id=current_user_id)


async def get_group_balance_summary_logic(
    db: AsyncSession, group_id: int, current_user_id: int
) -> GroupBalanceSummary:
    """
    Core logic to retrieve a detailed financial balance summary for a group.

    For every group member the function accumulates what they paid for
    expenses, their owed share from expense splits, what they paid through
    settlement activities, and generic settlements paid and received. The net
    balance is computed as::

        (paid_for_expenses + settlements_received)
        - ((share_of_expenses - paid_via_settlement_activities) + settlements_paid)

    Args:
        db: Async database session used for all queries.
        group_id: Id of the group to summarize.
        current_user_id: Id of the requesting user; must be a group member.

    Returns:
        A GroupBalanceSummary with per-user balances (sorted by identifier),
        overall totals, and suggested settlements.

    Raises:
        GroupNotFoundError: If no group with ``group_id`` exists.
        GroupPermissionError: If the current user is not a member of the group.
    """
    group_check_result = await db.execute(
        select(GroupModel)
        .options(selectinload(GroupModel.member_associations).options(selectinload(UserGroupModel.user)))
        .where(GroupModel.id == group_id)
    )
    db_group = group_check_result.scalars().first()

    if not db_group:
        raise GroupNotFoundError(group_id)

    if not any(assoc.user_id == current_user_id for assoc in db_group.member_associations):
        raise GroupPermissionError(group_id, "view balance summary for")

    expenses_result = await db.execute(
        select(ExpenseModel)
        .where(ExpenseModel.group_id == group_id)
        .options(selectinload(ExpenseModel.splits).selectinload(ExpenseSplitModel.user))
    )
    expenses = expenses_result.scalars().all()

    settlements_result = await db.execute(
        select(SettlementModel)
        .where(SettlementModel.group_id == group_id)
        .options(selectinload(SettlementModel.paid_by_user), selectinload(SettlementModel.paid_to_user))
    )
    settlements = settlements_result.scalars().all()

    # Settlement activities are reached through their split's expense, so
    # only activities belonging to this group's expenses are loaded.
    settlement_activities_result = await db.execute(
        select(SettlementActivityModel)
        .join(ExpenseSplitModel, SettlementActivityModel.expense_split_id == ExpenseSplitModel.id)
        .join(ExpenseModel, ExpenseSplitModel.expense_id == ExpenseModel.id)
        .where(ExpenseModel.group_id == group_id)
        .options(selectinload(SettlementActivityModel.payer))
    )
    settlement_activities = settlement_activities_result.scalars().all()

    # One accumulator per group member; amounts involving non-members are
    # skipped by the membership checks in the loops below.
    user_balances_data = {}
    for assoc in db_group.member_associations:
        if assoc.user:
            user_balances_data[assoc.user.id] = {
                "user_id": assoc.user.id,
                "user_identifier": assoc.user.name if assoc.user.name else assoc.user.email,
                "total_paid_for_expenses": Decimal("0.00"),
                "initial_total_share_of_expenses": Decimal("0.00"),
                "total_amount_paid_via_settlement_activities": Decimal("0.00"),
                "total_generic_settlements_paid": Decimal("0.00"),
                "total_generic_settlements_received": Decimal("0.00"),
            }

    for expense in expenses:
        if expense.paid_by_user_id in user_balances_data:
            user_balances_data[expense.paid_by_user_id]["total_paid_for_expenses"] += expense.total_amount
        for split in expense.splits:
            if split.user_id in user_balances_data:
                user_balances_data[split.user_id]["initial_total_share_of_expenses"] += split.owed_amount

    # NOTE(review): a settlement activity lowers the payer's outstanding share
    # here, but no member is credited as its recipient, so net balances may
    # not sum to zero once activities exist. Confirm this is intended.
    for activity in settlement_activities:
        if activity.paid_by_user_id in user_balances_data:
            user_balances_data[activity.paid_by_user_id]["total_amount_paid_via_settlement_activities"] += activity.amount_paid

    for settlement in settlements:
        if settlement.paid_by_user_id in user_balances_data:
            user_balances_data[settlement.paid_by_user_id]["total_generic_settlements_paid"] += settlement.amount
        if settlement.paid_to_user_id in user_balances_data:
            user_balances_data[settlement.paid_to_user_id]["total_generic_settlements_received"] += settlement.amount

    final_user_balances = []
    for data in user_balances_data.values():
        initial_total_share_of_expenses = data["initial_total_share_of_expenses"]
        total_amount_paid_via_settlement_activities = data["total_amount_paid_via_settlement_activities"]
        # Share still owed after payments made through settlement activities.
        adjusted_total_share_of_expenses = initial_total_share_of_expenses - total_amount_paid_via_settlement_activities
        total_paid_for_expenses = data["total_paid_for_expenses"]
        total_generic_settlements_received = data["total_generic_settlements_received"]
        total_generic_settlements_paid = data["total_generic_settlements_paid"]

        net_balance = (
            total_paid_for_expenses + total_generic_settlements_received
        ) - (adjusted_total_share_of_expenses + total_generic_settlements_paid)

        user_detail = UserBalanceDetail(
            user_id=data["user_id"],
            user_identifier=data["user_identifier"],
            total_paid_for_expenses=_round_money(total_paid_for_expenses),
            total_share_of_expenses=_round_money(adjusted_total_share_of_expenses),
            total_settlements_paid=_round_money(total_generic_settlements_paid),
            total_settlements_received=_round_money(total_generic_settlements_received),
            net_balance=_round_money(net_balance),
        )
        final_user_balances.append(user_detail)

    final_user_balances.sort(key=lambda x: x.user_identifier)

    suggested_settlements = calculate_suggested_settlements(final_user_balances)

    # Decimal start values: a group with no expenses or no settlements would
    # otherwise produce int 0 and fail on .quantize().
    overall_total_expenses = sum((expense.total_amount for expense in expenses), Decimal("0"))
    overall_total_settlements = sum((settlement.amount for settlement in settlements), Decimal("0"))

    return GroupBalanceSummary(
        group_id=db_group.id,
        group_name=db_group.name,
        overall_total_expenses=_round_money(overall_total_expenses),
        overall_total_settlements=_round_money(overall_total_settlements),
        user_balances=final_user_balances,
        suggested_settlements=suggested_settlements,
    )