import logging
from collections import defaultdict
from decimal import Decimal, ROUND_HALF_UP
from typing import Dict, Optional, Sequence, List as PyList

from sqlalchemy import func as sql_func, or_
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy.orm import selectinload, joinedload

from app.models import (
    List as ListModel,
    Item as ItemModel,
    User as UserModel,
    UserGroup as UserGroupModel,
    Group as GroupModel,
    Expense as ExpenseModel,
    ExpenseSplit as ExpenseSplitModel,
    Settlement as SettlementModel,
)
from app.schemas.cost import (
    ListCostSummary,
    UserCostShare,
    GroupBalanceSummary,
    UserBalanceDetail,
    SuggestedSettlement,
)
from app.core.exceptions import (
    ListNotFoundError,
    UserNotFoundError,
    GroupNotFoundError,
    InvalidOperationError,
)

logger = logging.getLogger(__name__)

# All monetary amounts in this module are reported to the cent, rounding halves up.
_CENT = Decimal("0.01")


def _to_cents(value: Decimal) -> Decimal:
    """Round a monetary amount to two decimal places using ROUND_HALF_UP."""
    return value.quantize(_CENT, rounding=ROUND_HALF_UP)


async def calculate_list_cost_summary(db: AsyncSession, list_id: int) -> ListCostSummary:
    """
    Calculates the cost summary for a given list based purely on item prices and who added them.
    This is a simpler calculation and does not involve the Expense/Settlement system.

    Participants are the members of the list's group (or the list creator for a personal list),
    plus every user who added an item with a positive price. The total of all positively-priced
    items is split equally between participants; the rounding remainder of that split is
    absorbed by the first participant processed.

    Args:
        db: Async database session.
        list_id: ID of the list to summarise.

    Returns:
        A ListCostSummary whose per-user entries hold the value of items the user added, the
        amount due, and the balance (items added minus amount due), sorted by user identifier.

    Raises:
        ListNotFoundError: If no list with ``list_id`` exists.
    """
    list_result = await db.execute(
        select(ListModel)
        .options(
            selectinload(ListModel.items).options(joinedload(ItemModel.added_by_user)),
            selectinload(ListModel.group).options(
                selectinload(GroupModel.user_associations).options(
                    selectinload(UserGroupModel.user)
                )
            ),
            selectinload(ListModel.creator),
        )
        .where(ListModel.id == list_id)
    )
    db_list: Optional[ListModel] = list_result.scalars().first()
    if not db_list:
        raise ListNotFoundError(list_id)

    participating_users_map: Dict[int, UserModel] = {}
    if db_list.group:
        for ug_assoc in db_list.group.user_associations:
            if ug_assoc.user:
                participating_users_map[ug_assoc.user.id] = ug_assoc.user
    elif db_list.creator:  # Personal list
        participating_users_map[db_list.creator.id] = db_list.creator

    # Only items with a positive price contribute to the cost summary.
    priced_items = [
        item
        for item in db_list.items
        if item.price is not None and item.price > Decimal("0")
    ]

    # Include all users who added priced items, even if not in the primary context (group/creator)
    for item in priced_items:
        if item.added_by_user and item.added_by_user.id not in participating_users_map:
            participating_users_map[item.added_by_user.id] = item.added_by_user

    num_participating_users = len(participating_users_map)
    if num_participating_users == 0:
        return ListCostSummary(
            list_id=db_list.id,
            list_name=db_list.name,
            total_list_cost=Decimal("0.00"),
            num_participating_users=0,
            equal_share_per_user=Decimal("0.00"),
            user_balances=[],
        )

    total_list_cost = Decimal("0.00")
    user_items_added_value: Dict[int, Decimal] = defaultdict(Decimal)
    for item in priced_items:
        total_list_cost += item.price
        if item.added_by_id in participating_users_map:
            user_items_added_value[item.added_by_id] += item.price

    equal_share_per_user = _to_cents(total_list_cost / Decimal(num_participating_users))
    # Rounding the equal share can leave the shares summing to slightly more or less than the
    # total; the first participant absorbs that difference (which may be negative).
    remainder = total_list_cost - (equal_share_per_user * num_participating_users)

    user_balances: PyList[UserCostShare] = []
    for index, (user_id, user_obj) in enumerate(participating_users_map.items()):
        items_added = user_items_added_value.get(user_id, Decimal("0.00"))
        current_user_share = equal_share_per_user
        if index == 0 and remainder != Decimal("0"):
            current_user_share += remainder

        balance = items_added - current_user_share
        user_balances.append(
            UserCostShare(
                user_id=user_id,
                user_identifier=user_obj.name or user_obj.email,
                items_added_value=_to_cents(items_added),
                amount_due=_to_cents(current_user_share),
                balance=_to_cents(balance),
            )
        )

    user_balances.sort(key=lambda share: share.user_identifier)

    return ListCostSummary(
        list_id=db_list.id,
        list_name=db_list.name,
        total_list_cost=_to_cents(total_list_cost),
        num_participating_users=num_participating_users,
        equal_share_per_user=_to_cents(equal_share_per_user),
        user_balances=user_balances,
    )


# --- Helper for Settlement Suggestions ---
def calculate_suggested_settlements(user_balances: PyList[UserBalanceDetail]) -> PyList[SuggestedSettlement]:
    """
    Calculates a list of suggested settlements to resolve group debts.

    Uses a greedy algorithm: debtors (most in debt first) are matched against creditors (most
    owed first), and each suggested transfer is the smaller of the two outstanding amounts.

    Balances are rounded to whole cents before matching. Every transfer is therefore exact and
    each step fully settles at least one party, which guarantees termination. Rounding each
    transfer individually could instead leave a sub-cent residue that never reaches zero, or
    over-pay a debtor when a half cent rounds up.

    Args:
        user_balances: UserBalanceDetail objects with calculated ``net_balance`` values
            (positive = the user is owed money, negative = the user owes money).

    Returns:
        SuggestedSettlement objects, each moving ``amount`` from a debtor to a creditor.
    """
    user_identifiers = {ub.user_id: ub.user_identifier for ub in user_balances}
    rounded = [(ub.user_id, _to_cents(ub.net_balance)) for ub in user_balances]

    debtors = sorted(
        [(uid, bal) for uid, bal in rounded if bal < 0], key=lambda entry: entry[1]
    )
    creditors = sorted(
        [(uid, bal) for uid, bal in rounded if bal > 0], key=lambda entry: entry[1], reverse=True
    )

    # Mutable, positive outstanding amounts for each side.
    debt_left = {uid: -bal for uid, bal in debtors}
    credit_left = {uid: bal for uid, bal in creditors}

    settlements: PyList[SuggestedSettlement] = []
    debtor_idx = 0
    creditor_idx = 0
    while debtor_idx < len(debtors) and creditor_idx < len(creditors):
        debtor_id = debtors[debtor_idx][0]
        creditor_id = creditors[creditor_idx][0]

        # Amount to transfer is the minimum of what debtor owes and what creditor is owed.
        transfer_amount = min(debt_left[debtor_id], credit_left[creditor_id])
        if transfer_amount > 0:
            settlements.append(
                SuggestedSettlement(
                    from_user_id=debtor_id,
                    from_user_identifier=user_identifiers.get(debtor_id, "Unknown Debtor"),
                    to_user_id=creditor_id,
                    to_user_identifier=user_identifiers.get(creditor_id, "Unknown Creditor"),
                    amount=transfer_amount,
                )
            )
            debt_left[debtor_id] -= transfer_amount
            credit_left[creditor_id] -= transfer_amount

        # Amounts are exact cents, so at least one side is now exactly zero.
        if debt_left[debtor_id] <= 0:
            debtor_idx += 1
        if credit_left[creditor_id] <= 0:
            creditor_idx += 1

    # Leftovers mean total debt and total credit did not match (e.g. rounding in the inputs).
    if debtor_idx < len(debtors) or creditor_idx < len(creditors):
        remaining_debt = -sum((amt for amt in debt_left.values() if amt > 0), Decimal("0"))
        remaining_credit = sum((amt for amt in credit_left.values() if amt > 0), Decimal("0"))
        logger.warning(
            "Settlement suggestion calculation finished with remaining balances. "
            "Debt: %s, Credit: %s. This might be due to minor rounding discrepancies.",
            remaining_debt,
            remaining_credit,
        )

    return settlements


# --- NEW: Detailed Group Balance Summary ---
async def calculate_group_balance_summary(db: AsyncSession, group_id: int) -> GroupBalanceSummary:
    """
    Calculates a detailed balance summary for all users in a group,
    considering all expenses, splits, and settlements within that group.
    Also calculates suggested settlements.

    A user's net balance is positive when the group owes them money and negative when they owe
    the group. Paying for an expense or paying a settlement raises it. A user's share of an
    expense or a settlement they receive lowers it.

    Args:
        db: Async database session.
        group_id: ID of the group to summarise.

    Returns:
        A GroupBalanceSummary with per-member totals (sorted by identifier), overall totals, and
        suggested settlements.

    Raises:
        GroupNotFoundError: If no group with ``group_id`` exists.
    """
    group = await db.get(GroupModel, group_id)
    if not group:
        raise GroupNotFoundError(group_id)

    # 1. Get all group members
    group_members_result = await db.execute(
        select(UserModel)
        .join(UserGroupModel, UserModel.id == UserGroupModel.user_id)
        .where(UserGroupModel.group_id == group_id)
    )
    group_members: Dict[int, UserModel] = {
        user.id: user for user in group_members_result.scalars().all()
    }

    if not group_members:
        return GroupBalanceSummary(
            group_id=group.id,
            group_name=group.name,
            user_balances=[],
            suggested_settlements=[],
        )

    user_balances_data: Dict[int, UserBalanceDetail] = {
        user_id: UserBalanceDetail(
            user_id=user_id,
            user_identifier=user_obj.name if user_obj.name else user_obj.email,
        )
        for user_id, user_obj in group_members.items()
    }

    overall_total_expenses = Decimal("0.00")
    overall_total_settlements = Decimal("0.00")

    # 2. Process Expenses and ExpenseSplits for the group
    expenses_result = await db.execute(
        select(ExpenseModel)
        .where(ExpenseModel.group_id == group_id)
        .options(selectinload(ExpenseModel.splits))
    )
    for expense in expenses_result.scalars().all():
        overall_total_expenses += expense.total_amount
        if expense.paid_by_user_id in user_balances_data:
            user_balances_data[expense.paid_by_user_id].total_paid_for_expenses += expense.total_amount

        for split in expense.splits:
            if split.user_id in user_balances_data:
                user_balances_data[split.user_id].total_share_of_expenses += split.owed_amount

    # 3. Process Settlements for the group
    settlements_result = await db.execute(
        select(SettlementModel).where(SettlementModel.group_id == group_id)
    )
    for settlement in settlements_result.scalars().all():
        overall_total_settlements += settlement.amount
        if settlement.paid_by_user_id in user_balances_data:
            user_balances_data[settlement.paid_by_user_id].total_settlements_paid += settlement.amount
        if settlement.paid_to_user_id in user_balances_data:
            user_balances_data[settlement.paid_to_user_id].total_settlements_received += settlement.amount

    # 4. Calculate net balances and prepare final list
    final_user_balances: PyList[UserBalanceDetail] = []
    for user_id in group_members:
        data = user_balances_data[user_id]
        # Paying a settlement reduces what the payer owes (raises their balance); receiving one
        # reduces what the recipient is owed (lowers their balance). This matches the sign
        # convention used by calculate_suggested_settlements, so recording a suggested
        # settlement brings both parties toward zero.
        data.net_balance = (
            data.total_paid_for_expenses + data.total_settlements_paid
        ) - (data.total_share_of_expenses + data.total_settlements_received)

        data.total_paid_for_expenses = _to_cents(data.total_paid_for_expenses)
        data.total_share_of_expenses = _to_cents(data.total_share_of_expenses)
        data.total_settlements_paid = _to_cents(data.total_settlements_paid)
        data.total_settlements_received = _to_cents(data.total_settlements_received)
        data.net_balance = _to_cents(data.net_balance)

        final_user_balances.append(data)

    final_user_balances.sort(key=lambda detail: detail.user_identifier)

    # 5. Calculate suggested settlements
    suggested_settlements = calculate_suggested_settlements(final_user_balances)

    return GroupBalanceSummary(
        group_id=group.id,
        group_name=group.name,
        overall_total_expenses=_to_cents(overall_total_expenses),
        overall_total_settlements=_to_cents(overall_total_settlements),
        user_balances=final_user_balances,
        suggested_settlements=suggested_settlements,
    )