# mitlist/be/app/crud/cost.py
# 2025-05-08 00:56:26 +02:00
#
# 266 lines
# 12 KiB
# Python

from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy.orm import selectinload, joinedload
from sqlalchemy import func as sql_func, or_
from decimal import Decimal, ROUND_HALF_UP
from typing import Dict, Optional, Sequence, List as PyList
from collections import defaultdict
import logging
from app.models import (
List as ListModel,
Item as ItemModel,
User as UserModel,
UserGroup as UserGroupModel,
Group as GroupModel,
Expense as ExpenseModel,
ExpenseSplit as ExpenseSplitModel,
Settlement as SettlementModel
)
from app.schemas.cost import (
ListCostSummary,
UserCostShare,
GroupBalanceSummary,
UserBalanceDetail,
SuggestedSettlement
)
from app.core.exceptions import (
ListNotFoundError,
UserNotFoundError,
GroupNotFoundError,
InvalidOperationError
)
# Module-level logger; e.g. calculate_suggested_settlements uses it to warn when
# balances remain after generating settlement suggestions.
logger = logging.getLogger(__name__)
def _distribute_rounding_remainder(
    total: Decimal, base_share: Decimal, user_ids: Sequence[int]
) -> Dict[int, Decimal]:
    """
    Return each user's share of ``total`` in whole cents, starting from ``base_share``.

    ``base_share`` must be ``total / len(user_ids)`` rounded to the cent, and ``total``
    must already be whole cents. ``base_share * len(user_ids)`` can then miss ``total``
    by at most half a cent per user. The difference is handed out one cent at a time
    (added when the shares fall short, removed when they overshoot) to users in
    ascending id order. The shares therefore always sum to ``total``, no user absorbs
    more than one cent of the remainder, and the result does not depend on the order
    in which users were loaded.
    """
    cent = Decimal("0.01")
    # Both operands are whole cents, so this division is exact and integral.
    remainder_cents = int((total - base_share * len(user_ids)) / cent)
    adjustment = cent if remainder_cents > 0 else -cent
    return {
        user_id: (base_share + adjustment) if position < abs(remainder_cents) else base_share
        for position, user_id in enumerate(sorted(user_ids))
    }


async def calculate_list_cost_summary(db: AsyncSession, list_id: int) -> ListCostSummary:
    """
    Calculates the cost summary for a given list based purely on item prices and who added them.
    This is a simpler calculation and does not involve the Expense/Settlement system.

    Participants are the members of the list's group (or, for a personal list, its
    creator), plus anyone who added an item with a positive price. The summed price of
    all positively priced items is split evenly between participants in whole cents.
    Any rounding remainder is spread one cent at a time in ascending user-id order, so
    the amounts due always add up to the reported total.

    Args:
        db: Async database session.
        list_id: Primary key of the list to summarise.

    Returns:
        A ListCostSummary whose user_balances are sorted by user identifier. Each
        balance is ``items_added_value - amount_due``; a positive balance means the
        user added more than their share.

    Raises:
        ListNotFoundError: If no list with ``list_id`` exists.
    """
    cent = Decimal("0.01")

    list_result = await db.execute(
        select(ListModel)
        .options(
            selectinload(ListModel.items).options(joinedload(ItemModel.added_by_user)),
            selectinload(ListModel.group).options(
                selectinload(GroupModel.user_associations).options(selectinload(UserGroupModel.user))
            ),
            selectinload(ListModel.creator),
        )
        .where(ListModel.id == list_id)
    )
    db_list: Optional[ListModel] = list_result.scalars().first()
    if not db_list:
        raise ListNotFoundError(list_id)

    participating_users_map: Dict[int, UserModel] = {}
    if db_list.group:
        for ug_assoc in db_list.group.user_associations:
            if ug_assoc.user:
                participating_users_map[ug_assoc.user.id] = ug_assoc.user
    elif db_list.creator:  # Personal list
        participating_users_map[db_list.creator.id] = db_list.creator

    # Only items with a positive price contribute to the cost.
    priced_items = [
        item for item in db_list.items
        if item.price is not None and item.price > Decimal("0")
    ]

    # Include everyone who added a priced item, even if they are outside the
    # primary context (group members / creator).
    for item in priced_items:
        adder = item.added_by_user
        if adder and adder.id not in participating_users_map:
            participating_users_map[adder.id] = adder

    num_participating_users = len(participating_users_map)
    if num_participating_users == 0:
        return ListCostSummary(
            list_id=db_list.id, list_name=db_list.name, total_list_cost=Decimal("0.00"),
            num_participating_users=0, equal_share_per_user=Decimal("0.00"), user_balances=[]
        )

    total_list_cost = Decimal("0.00")
    user_items_added_value: Dict[int, Decimal] = defaultdict(Decimal)
    for item in priced_items:
        total_list_cost += item.price
        if item.added_by_id in participating_users_map:
            user_items_added_value[item.added_by_id] += item.price

    # Split the cent-rounded total so the amounts due add up to exactly the reported total.
    total_cents = total_list_cost.quantize(cent, rounding=ROUND_HALF_UP)
    equal_share_per_user = (total_cents / Decimal(num_participating_users)).quantize(
        cent, rounding=ROUND_HALF_UP
    )
    amounts_due = _distribute_rounding_remainder(
        total_cents, equal_share_per_user, list(participating_users_map)
    )

    user_balances: PyList[UserCostShare] = []
    for user_id, user_obj in participating_users_map.items():
        items_added = user_items_added_value.get(user_id, Decimal("0.00"))
        amount_due = amounts_due[user_id]
        balance = items_added - amount_due
        user_balances.append(
            UserCostShare(
                user_id=user_id,
                user_identifier=user_obj.name if user_obj.name else user_obj.email,
                items_added_value=items_added.quantize(cent, rounding=ROUND_HALF_UP),
                amount_due=amount_due.quantize(cent, rounding=ROUND_HALF_UP),
                balance=balance.quantize(cent, rounding=ROUND_HALF_UP),
            )
        )
    user_balances.sort(key=lambda x: x.user_identifier)

    return ListCostSummary(
        list_id=db_list.id,
        list_name=db_list.name,
        total_list_cost=total_cents,
        num_participating_users=num_participating_users,
        equal_share_per_user=equal_share_per_user,
        user_balances=user_balances,
    )
# --- Helper for Settlement Suggestions ---
def calculate_suggested_settlements(user_balances: PyList[UserBalanceDetail]) -> PyList[SuggestedSettlement]:
    """
    Suggest payments that settle the debts described by ``user_balances``.

    Greedy matching: debtors (most negative net balance first) are paired with
    creditors (largest positive net balance first). Each step transfers the smaller of
    the two outstanding amounts, rounded to the cent. This keeps the number of
    transfers small, but it is not guaranteed to be the minimum.

    Args:
        user_balances: Per-user balances with ``net_balance`` already computed
            (negative = owes money, positive = is owed money).

    Returns:
        SuggestedSettlement entries in the order they were generated.
    """
    cent = Decimal("0.01")
    # Balances within this tolerance of zero are treated as already settled.
    tolerance = Decimal("0.001")
    # Amounts below half a cent round to a 0.00 payment, so a party with less than
    # this outstanding counts as settled. Without this rule, unrounded inputs
    # (e.g. -0.004) yielded 0.00 transfers forever and the loop never terminated.
    # For cent-granular inputs this is equivalent to "remaining amount is zero".
    half_cent = Decimal("0.005")

    debtors = sorted((ub for ub in user_balances if ub.net_balance < -tolerance), key=lambda x: x.net_balance)
    creditors = sorted(
        (ub for ub in user_balances if ub.net_balance > tolerance), key=lambda x: x.net_balance, reverse=True
    )
    user_identifiers = {ub.user_id: ub.user_identifier for ub in user_balances}

    # Mutable copies of balances to track remaining amounts (debtor values stay negative).
    debtor_balances = {d.user_id: d.net_balance for d in debtors}
    creditor_balances = {c.user_id: c.net_balance for c in creditors}

    settlements: PyList[SuggestedSettlement] = []
    debtor_idx = 0
    creditor_idx = 0
    while debtor_idx < len(debtors) and creditor_idx < len(creditors):
        debtor_id = debtors[debtor_idx].user_id
        creditor_id = creditors[creditor_idx].user_id

        # Transfer what the debtor owes or what the creditor is owed, whichever is smaller.
        transfer_amount = min(-debtor_balances[debtor_id], creditor_balances[creditor_id]).quantize(
            cent, rounding=ROUND_HALF_UP
        )
        if transfer_amount > tolerance:  # Only record meaningful transfers
            settlements.append(SuggestedSettlement(
                from_user_id=debtor_id,
                from_user_identifier=user_identifiers.get(debtor_id, "Unknown Debtor"),
                to_user_id=creditor_id,
                to_user_identifier=user_identifiers.get(creditor_id, "Unknown Creditor"),
                amount=transfer_amount
            ))
            debtor_balances[debtor_id] += transfer_amount
            creditor_balances[creditor_id] -= transfer_amount

        # At least one side always drops below half a cent here, so the loop always progresses.
        if debtor_balances[debtor_id] > -half_cent:
            debtor_idx += 1
        if creditor_balances[creditor_id] < half_cent:
            creditor_idx += 1

    # Leftover parties indicate a group imbalance or rounding discrepancy.
    if debtor_idx < len(debtors) or creditor_idx < len(creditors):
        remaining_debt = sum(bal for bal in debtor_balances.values() if bal < -tolerance)
        remaining_credit = sum(bal for bal in creditor_balances.values() if bal > tolerance)
        logger.warning(
            "Settlement suggestion calculation finished with remaining balances. Debt: %s, Credit: %s. "
            "This might be due to minor rounding discrepancies.",
            remaining_debt,
            remaining_credit,
        )
    return settlements
# --- Detailed Group Balance Summary ---
async def calculate_group_balance_summary(db: AsyncSession, group_id: int) -> GroupBalanceSummary:
    """
    Build a per-member balance sheet for a group, plus suggested settlement payments.

    For every current member of the group (via UserGroup rows), this totals:
      * what they paid for group expenses (``paid_by_user_id``),
      * their share of group expenses (``ExpenseSplit.owed_amount``),
      * settlements they paid and received,
    and derives ``net_balance = (paid + received) - (share + settlements paid)``.

    Every group expense and settlement counts toward the overall totals. Amounts whose
    payer, split user or payee is not a current member are left out of the per-user
    figures.

    Raises:
        GroupNotFoundError: If no group with ``group_id`` exists.
    """
    def to_cents(amount: Decimal) -> Decimal:
        # Round a monetary amount to two decimal places, half-up.
        return amount.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP)

    group = await db.get(GroupModel, group_id)
    if not group:
        raise GroupNotFoundError(group_id)

    # Current members of the group.
    members_result = await db.execute(
        select(UserModel)
        .join(UserGroupModel, UserModel.id == UserGroupModel.user_id)
        .where(UserGroupModel.group_id == group_id)
    )
    members: Dict[int, UserModel] = {member.id: member for member in members_result.scalars().all()}
    if not members:
        return GroupBalanceSummary(
            group_id=group.id, group_name=group.name, user_balances=[], suggested_settlements=[]
        )

    details: Dict[int, UserBalanceDetail] = {
        member_id: UserBalanceDetail(user_id=member_id, user_identifier=member.name or member.email)
        for member_id, member in members.items()
    }
    expenses_sum = Decimal("0.00")
    settlements_sum = Decimal("0.00")

    # Expenses (with their splits) recorded against this group.
    expenses_result = await db.execute(
        select(ExpenseModel)
        .where(ExpenseModel.group_id == group_id)
        .options(selectinload(ExpenseModel.splits))
    )
    for expense in expenses_result.scalars().all():
        expenses_sum += expense.total_amount
        payer = details.get(expense.paid_by_user_id)
        if payer is not None:
            payer.total_paid_for_expenses += expense.total_amount
        for split in expense.splits:
            split_owner = details.get(split.user_id)
            if split_owner is not None:
                split_owner.total_share_of_expenses += split.owed_amount

    # Settlements recorded against this group.
    settlements_result = await db.execute(
        select(SettlementModel).where(SettlementModel.group_id == group_id)
    )
    for settlement in settlements_result.scalars().all():
        settlements_sum += settlement.amount
        sender = details.get(settlement.paid_by_user_id)
        if sender is not None:
            sender.total_settlements_paid += settlement.amount
        receiver = details.get(settlement.paid_to_user_id)
        if receiver is not None:
            receiver.total_settlements_received += settlement.amount

    # Net balance from unrounded totals, then round every reported figure.
    for detail in details.values():
        detail.net_balance = (
            detail.total_paid_for_expenses + detail.total_settlements_received
        ) - (detail.total_share_of_expenses + detail.total_settlements_paid)
        for field_name in (
            "total_paid_for_expenses",
            "total_share_of_expenses",
            "total_settlements_paid",
            "total_settlements_received",
            "net_balance",
        ):
            setattr(detail, field_name, to_cents(getattr(detail, field_name)))

    ordered_balances: PyList[UserBalanceDetail] = sorted(details.values(), key=lambda d: d.user_identifier)

    return GroupBalanceSummary(
        group_id=group.id,
        group_name=group.name,
        overall_total_expenses=to_cents(expenses_sum),
        overall_total_settlements=to_cents(settlements_sum),
        user_balances=ordered_balances,
        suggested_settlements=calculate_suggested_settlements(ordered_balances),
    )