266 lines
12 KiB
Python
266 lines
12 KiB
Python
from sqlalchemy.ext.asyncio import AsyncSession
|
|
from sqlalchemy.future import select
|
|
from sqlalchemy.orm import selectinload, joinedload
|
|
from sqlalchemy import func as sql_func, or_
|
|
from decimal import Decimal, ROUND_HALF_UP
|
|
from typing import Dict, Optional, Sequence, List as PyList
|
|
from collections import defaultdict
|
|
import logging
|
|
|
|
from app.models import (
|
|
List as ListModel,
|
|
Item as ItemModel,
|
|
User as UserModel,
|
|
UserGroup as UserGroupModel,
|
|
Group as GroupModel,
|
|
Expense as ExpenseModel,
|
|
ExpenseSplit as ExpenseSplitModel,
|
|
Settlement as SettlementModel
|
|
)
|
|
from app.schemas.cost import (
|
|
ListCostSummary,
|
|
UserCostShare,
|
|
GroupBalanceSummary,
|
|
UserBalanceDetail,
|
|
SuggestedSettlement
|
|
)
|
|
from app.core.exceptions import (
|
|
ListNotFoundError,
|
|
UserNotFoundError,
|
|
GroupNotFoundError,
|
|
InvalidOperationError
|
|
)
|
|
|
|
# Module-level logger named after this module, so its level/handlers can be configured per module.
logger = logging.getLogger(name=__name__)
|
|
|
|
async def calculate_list_cost_summary(db: AsyncSession, list_id: int) -> ListCostSummary:
    """
    Calculate the cost summary for a list from item prices and who added each item.

    This is a simple, list-local calculation; it does not involve the
    Expense/Settlement system.

    Participants are the members of the list's group (or, for a list without a
    group, its creator), plus anyone who added a positively priced item. The sum
    of all positively priced items is split evenly between participants, and each
    user's balance is the value of the items they added minus their share.

    Args:
        db: Async database session.
        list_id: Primary key of the list to summarise.

    Returns:
        A ``ListCostSummary``; ``user_balances`` is sorted by user identifier and
        all monetary values are rounded to cents (ROUND_HALF_UP).

    Raises:
        ListNotFoundError: If no list with ``list_id`` exists.
    """
    cent = Decimal("0.01")

    def to_money(value: Decimal) -> Decimal:
        # Round a monetary amount to whole cents.
        return value.quantize(cent, rounding=ROUND_HALF_UP)

    list_result = await db.execute(
        select(ListModel)
        .options(
            selectinload(ListModel.items).options(joinedload(ItemModel.added_by_user)),
            selectinload(ListModel.group).options(
                selectinload(GroupModel.user_associations).options(selectinload(UserGroupModel.user))
            ),
            selectinload(ListModel.creator),
        )
        .where(ListModel.id == list_id)
    )
    db_list: Optional[ListModel] = list_result.scalars().first()
    if not db_list:
        raise ListNotFoundError(list_id)

    # Only items with a positive price take part in the cost summary.
    priced_items = [
        item for item in db_list.items
        if item.price is not None and item.price > Decimal("0")
    ]

    participating_users_map: Dict[int, UserModel] = {}
    if db_list.group:
        for ug_assoc in db_list.group.user_associations:
            if ug_assoc.user:
                participating_users_map[ug_assoc.user.id] = ug_assoc.user
    elif db_list.creator:  # Personal list
        participating_users_map[db_list.creator.id] = db_list.creator

    # Include all users who added priced items, even if not in the primary context (group/creator).
    for item in priced_items:
        if item.added_by_user and item.added_by_user.id not in participating_users_map:
            participating_users_map[item.added_by_user.id] = item.added_by_user

    num_participating_users = len(participating_users_map)
    if num_participating_users == 0:
        return ListCostSummary(
            list_id=db_list.id, list_name=db_list.name, total_list_cost=Decimal("0.00"),
            num_participating_users=0, equal_share_per_user=Decimal("0.00"), user_balances=[]
        )

    total_list_cost = Decimal("0.00")
    user_items_added_value: Dict[int, Decimal] = defaultdict(Decimal)
    for item in priced_items:
        total_list_cost += item.price
        if item.added_by_id in participating_users_map:
            user_items_added_value[item.added_by_id] += item.price

    equal_share_per_user = to_money(total_list_cost / Decimal(num_participating_users))

    # Rounding the equal share means the shares may not add back up to the total.
    # |remainder| is at most half a cent per user, so hand it out one cent per user
    # (any sub-cent residue goes to the first user) instead of loading it all onto a
    # single person, which could otherwise give that person a negative share.
    remainder = total_list_cost - (equal_share_per_user * num_participating_users)
    cent_step = cent if remainder > 0 else -cent
    whole_cents = min(int(abs(remainder) // cent), num_participating_users)
    sub_cent_residue = remainder - cent_step * whole_cents

    user_balances: PyList[UserCostShare] = []
    for position, (user_id, user_obj) in enumerate(participating_users_map.items()):
        current_user_share = equal_share_per_user
        if position < whole_cents:
            current_user_share += cent_step
        if position == 0:
            current_user_share += sub_cent_residue

        items_added = user_items_added_value.get(user_id, Decimal("0.00"))
        balance = items_added - current_user_share
        user_identifier = user_obj.name if user_obj.name else user_obj.email
        user_balances.append(
            UserCostShare(
                user_id=user_id,
                user_identifier=user_identifier,
                items_added_value=to_money(items_added),
                amount_due=to_money(current_user_share),
                balance=to_money(balance),
            )
        )

    user_balances.sort(key=lambda share: share.user_identifier)

    return ListCostSummary(
        list_id=db_list.id,
        list_name=db_list.name,
        total_list_cost=to_money(total_list_cost),
        num_participating_users=num_participating_users,
        equal_share_per_user=equal_share_per_user,  # already rounded to cents above
        user_balances=user_balances,
    )
|
|
|
|
# --- Helper for Settlement Suggestions ---
|
|
def calculate_suggested_settlements(user_balances: PyList[UserBalanceDetail]) -> PyList[SuggestedSettlement]:
    """
    Suggest payments that settle the outstanding balances in ``user_balances``.

    Greedy heuristic: debtors (most negative balance first) are paired with
    creditors (largest positive balance first), and each pairing transfers the
    smaller of the two outstanding amounts. Every step fully settles at least one
    party, so at most ``len(debtors) + len(creditors) - 1`` payments are
    suggested. This keeps the number of transactions low but is not guaranteed
    to be the absolute minimum.

    Args:
        user_balances: Per-user balances. Only ``user_id``, ``user_identifier`` and
            ``net_balance`` are read. A negative ``net_balance`` means the user
            owes money (pays), a positive one means the user is owed money
            (receives).

    Returns:
        Suggested settlements with amounts rounded to cents (ROUND_HALF_UP);
        transfers that round to nothing are omitted. If debts and credits do not
        cancel out, a warning is logged with the unsettled remainder.
    """
    # Balances within this distance of zero are treated as already settled.
    tolerance = Decimal("0.001")
    cent = Decimal("0.01")

    debtors = sorted(
        (ub for ub in user_balances if ub.net_balance < -tolerance),
        key=lambda ub: ub.net_balance,
    )
    creditors = sorted(
        (ub for ub in user_balances if ub.net_balance > tolerance),
        key=lambda ub: ub.net_balance,
        reverse=True,
    )

    # Mutable copies of the outstanding amounts, updated as transfers are planned.
    debtor_balances = {d.user_id: d.net_balance for d in debtors}
    creditor_balances = {c.user_id: c.net_balance for c in creditors}
    user_identifiers = {ub.user_id: ub.user_identifier for ub in user_balances}

    settlements: PyList[SuggestedSettlement] = []
    debtor_idx = 0
    creditor_idx = 0

    while debtor_idx < len(debtors) and creditor_idx < len(creditors):
        debtor = debtors[debtor_idx]
        creditor = creditors[creditor_idx]

        amount_owed = -debtor_balances[debtor.user_id]
        amount_due = creditor_balances[creditor.user_id]

        # Transfer the smaller outstanding amount. An earlier round-up can leave a
        # party over-settled by under half a cent, so never plan a negative transfer.
        raw_transfer = max(min(amount_owed, amount_due), Decimal("0"))
        transfer_amount = raw_transfer.quantize(cent, rounding=ROUND_HALF_UP)

        if transfer_amount > tolerance:  # Only record meaningful transfers
            settlements.append(SuggestedSettlement(
                from_user_id=debtor.user_id,
                from_user_identifier=user_identifiers.get(debtor.user_id, "Unknown Debtor"),
                to_user_id=creditor.user_id,
                to_user_identifier=user_identifiers.get(creditor.user_id, "Unknown Creditor"),
                amount=transfer_amount
            ))

        debtor_balances[debtor.user_id] += transfer_amount
        creditor_balances[creditor.user_id] -= transfer_amount

        # Move past whichever side limited the transfer (both, if equal). Deciding
        # this from the pre-rounding amounts, rather than by re-testing the rounded
        # balances against the tolerance, guarantees the loop always makes progress
        # even when rounding to cents leaves a sub-cent residue.
        if amount_owed <= amount_due:
            debtor_idx += 1
        if amount_due <= amount_owed:
            creditor_idx += 1

    # Leftover parties indicate debts and credits that did not cancel (imbalance or rounding).
    if debtor_idx < len(debtors) or creditor_idx < len(creditors):
        remaining_debt = sum(bal for bal in debtor_balances.values() if bal < -tolerance)
        remaining_credit = sum(bal for bal in creditor_balances.values() if bal > tolerance)
        logger.warning(
            "Settlement suggestion calculation finished with remaining balances. Debt: %s, Credit: %s. "
            "This might be due to minor rounding discrepancies.",
            remaining_debt,
            remaining_credit,
        )

    return settlements
|
|
|
|
# --- Detailed Group Balance Summary ---
|
|
async def calculate_group_balance_summary(db: AsyncSession, group_id: int) -> GroupBalanceSummary:
    """
    Calculate a detailed balance summary for every member of a group.

    Considers all expenses (with their splits) and all settlements recorded for
    the group, then derives suggested settlements from the resulting balances.

    A member's ``net_balance`` is positive when they are owed money and negative
    when they owe money:

        net = paid_for_expenses - share_of_expenses
              + settlements_paid - settlements_received

    For example, if A pays a 100 expense split 50/50 with B and B then pays A 50,
    both A and B end with a net balance of 0.

    Only current group members get a balance entry. Expenses, splits and
    settlements involving other users still count toward the overall totals but
    are not attributed to anyone.

    Args:
        db: Async database session.
        group_id: Primary key of the group.

    Returns:
        A ``GroupBalanceSummary`` with per-member balances sorted by user
        identifier, overall totals, and suggested settlements; amounts are
        rounded to cents (ROUND_HALF_UP).

    Raises:
        GroupNotFoundError: If no group with ``group_id`` exists.
    """
    cent = Decimal("0.01")

    def to_money(value: Decimal) -> Decimal:
        # Round a monetary amount to whole cents.
        return value.quantize(cent, rounding=ROUND_HALF_UP)

    group = await db.get(GroupModel, group_id)
    if not group:
        raise GroupNotFoundError(group_id)

    # 1. Get all group members
    group_members_result = await db.execute(
        select(UserModel)
        .join(UserGroupModel, UserModel.id == UserGroupModel.user_id)
        .where(UserGroupModel.group_id == group_id)
    )
    group_members: Dict[int, UserModel] = {user.id: user for user in group_members_result.scalars().all()}
    if not group_members:
        return GroupBalanceSummary(
            group_id=group.id, group_name=group.name, user_balances=[], suggested_settlements=[]
        )

    user_balances_data: Dict[int, UserBalanceDetail] = {
        user_id: UserBalanceDetail(
            user_id=user_id,
            user_identifier=user_obj.name if user_obj.name else user_obj.email
        )
        for user_id, user_obj in group_members.items()
    }

    overall_total_expenses = Decimal("0.00")
    overall_total_settlements = Decimal("0.00")

    # 2. Process Expenses and ExpenseSplits for the group
    expenses_result = await db.execute(
        select(ExpenseModel)
        .where(ExpenseModel.group_id == group_id)
        .options(selectinload(ExpenseModel.splits))
    )
    for expense in expenses_result.scalars().all():
        overall_total_expenses += expense.total_amount
        if expense.paid_by_user_id in user_balances_data:
            user_balances_data[expense.paid_by_user_id].total_paid_for_expenses += expense.total_amount

        for split in expense.splits:
            if split.user_id in user_balances_data:
                user_balances_data[split.user_id].total_share_of_expenses += split.owed_amount

    # 3. Process Settlements for the group
    settlements_result = await db.execute(
        select(SettlementModel).where(SettlementModel.group_id == group_id)
    )
    for settlement in settlements_result.scalars().all():
        overall_total_settlements += settlement.amount
        if settlement.paid_by_user_id in user_balances_data:
            user_balances_data[settlement.paid_by_user_id].total_settlements_paid += settlement.amount
        if settlement.paid_to_user_id in user_balances_data:
            user_balances_data[settlement.paid_to_user_id].total_settlements_received += settlement.amount

    # 4. Calculate net balances and prepare final list
    final_user_balances: PyList[UserBalanceDetail] = []
    for user_id in group_members:
        data = user_balances_data[user_id]

        # Paying a settlement reduces what the payer owes (raises their balance);
        # receiving one reduces what the recipient is owed (lowers their balance).
        data.net_balance = (
            data.total_paid_for_expenses + data.total_settlements_paid
        ) - (data.total_share_of_expenses + data.total_settlements_received)

        data.total_paid_for_expenses = to_money(data.total_paid_for_expenses)
        data.total_share_of_expenses = to_money(data.total_share_of_expenses)
        data.total_settlements_paid = to_money(data.total_settlements_paid)
        data.total_settlements_received = to_money(data.total_settlements_received)
        data.net_balance = to_money(data.net_balance)

        final_user_balances.append(data)

    final_user_balances.sort(key=lambda detail: detail.user_identifier)

    # 5. Calculate suggested settlements from the rounded member balances
    suggested_settlements = calculate_suggested_settlements(final_user_balances)

    return GroupBalanceSummary(
        group_id=group.id,
        group_name=group.name,
        overall_total_expenses=to_money(overall_total_expenses),
        overall_total_settlements=to_money(overall_total_settlements),
        user_balances=final_user_balances,
        suggested_settlements=suggested_settlements
    )