# app/api/v1/endpoints/costs.py
import logging
from decimal import Decimal, ROUND_HALF_UP

from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session, selectinload

from app.database import get_db
from app.api.dependencies import get_current_user
from app.models import (
    User as UserModel,
    Group as GroupModel,
    List as ListModel,
    Expense as ExpenseModel,
    Item as ItemModel,
    UserGroup as UserGroupModel,
    SplitTypeEnum,
    ExpenseSplit as ExpenseSplitModel,
    Settlement as SettlementModel
)
from app.schemas.cost import (
    GroupBalanceSummary,
    ListCostSummary,
    UserBalanceDetail,
    UserCostShare,
)
from app.schemas.expense import ExpenseCreate
from app.crud import list as crud_list
from app.crud import expense as crud_expense
from app.core.exceptions import ListNotFoundError, ListPermissionError, UserNotFoundError, GroupNotFoundError

# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
# Router on which the cost-related endpoints below are registered.
router = APIRouter()

def _empty_list_cost_summary(db_list: ListModel) -> ListCostSummary:
    """Build the all-zero summary returned when a list has nothing to split."""
    return ListCostSummary(
        list_id=db_list.id,
        list_name=db_list.name,
        total_list_cost=Decimal("0.00"),
        num_participating_users=0,
        equal_share_per_user=Decimal("0.00"),
        user_balances=[],
    )


@router.get(
    "/lists/{list_id}/cost-summary",
    response_model=ListCostSummary,
    summary="Get Cost Summary for a List",
    tags=["Costs"],
    responses={
        status.HTTP_403_FORBIDDEN: {"description": "User does not have permission to access this list"},
        status.HTTP_404_NOT_FOUND: {"description": "List or associated user not found"}
    }
)
async def get_list_cost_summary(
    list_id: int,
    db: AsyncSession = Depends(get_db),
    current_user: UserModel = Depends(get_current_user),
):
    """
    Compute an equal-split cost summary for a list.

    The total cost is the sum of all positively priced items that have a
    known adding user. Participants are every user who added such an item
    plus every user attached to a split of the list's expense. The total is
    divided equally between participants (rounded half-up to cents); any
    rounding remainder is assigned to the participant with the lowest user
    id so the result is the same on every request. Each participant's
    balance is ``items_added_value - amount_due``.

    Side effect: if the list has no expense yet and has priced items, an
    ITEM_BASED expense is created with the current user recorded as payer.

    Args:
        list_id: Id of the list to summarise.
        db: Async database session.
        current_user: Authenticated user making the request.

    Returns:
        A ``ListCostSummary`` whose ``user_balances`` are sorted by user
        identifier (name, falling back to email). An all-zero summary is
        returned when there is nothing to split.

    Raises:
        ListPermissionError: Re-raised from the permission check.
        ListNotFoundError: Re-raised from the permission check, or raised
            here when the list cannot be loaded.
    """
    logger.info("User %s requesting cost summary for list %s", current_user.email, list_id)

    # 1. Verify user has access to the target list (log, then let the
    #    domain exception propagate unchanged).
    try:
        await crud_list.check_list_permission(db=db, list_id=list_id, user_id=current_user.id)
    except ListPermissionError as e:
        logger.warning("Permission denied for user %s on list %s: %s", current_user.email, list_id, e)
        raise
    except ListNotFoundError as e:
        logger.warning("List %s not found when checking permissions for cost summary: %s", list_id, e)
        raise

    # 2. Load the list with everything read below eagerly attached; lazy
    #    loading is not available on an AsyncSession.
    list_result = await db.execute(
        select(ListModel)
        .options(
            selectinload(ListModel.items).options(selectinload(ItemModel.added_by_user)),
            selectinload(ListModel.group).options(selectinload(GroupModel.member_associations).options(selectinload(UserGroupModel.user))),
            selectinload(ListModel.creator)
        )
        .where(ListModel.id == list_id)
    )
    db_list = list_result.scalars().first()
    if not db_list:
        raise ListNotFoundError(list_id)

    # 3. Get or create an expense for this list. `split.user` is read in
    #    step 4, so it must be eager-loaded together with the splits.
    expense_result = await db.execute(
        select(ExpenseModel)
        .where(ExpenseModel.list_id == list_id)
        .options(selectinload(ExpenseModel.splits).selectinload(ExpenseSplitModel.user))
    )
    db_expense = expense_result.scalars().first()

    if db_expense is None:
        # Decimal start value keeps the total a Decimal even with no items.
        total_amount = sum(
            (
                item.price
                for item in db_list.items
                if item.price is not None and item.price > Decimal("0")
            ),
            Decimal("0"),
        )
        if total_amount == Decimal("0"):
            return _empty_list_cost_summary(db_list)

        # Create expense with ITEM_BASED split type
        expense_in = ExpenseCreate(
            description=f"Cost summary for list {db_list.name}",
            total_amount=total_amount,
            list_id=list_id,
            split_type=SplitTypeEnum.ITEM_BASED,
            paid_by_user_id=current_user.id  # Use current user as payer for now
        )
        # NOTE(review): assumes create_expense returns the expense with
        # `splits` (and each split's `user`) already loaded — confirm.
        db_expense = await crud_expense.create_expense(db=db, expense_in=expense_in)

    # 4. Collect participants keyed by user id so each user is counted once
    #    and iteration order can be made deterministic.
    participants = {}
    items_added_by_user = {}
    total_list_cost = Decimal("0.00")

    for item in db_list.items:
        adder = item.added_by_user
        if item.price is not None and item.price > Decimal("0") and adder:
            participants.setdefault(adder.id, adder)
            items_added_by_user[adder.id] = items_added_by_user.get(adder.id, Decimal("0.00")) + item.price
            total_list_cost += item.price

    for split in db_expense.splits:
        if split.user:
            participants.setdefault(split.user.id, split.user)

    num_participating_users = len(participants)
    if num_participating_users == 0:
        return _empty_list_cost_summary(db_list)

    cent = Decimal("0.01")
    equal_share_per_user = (total_list_cost / Decimal(num_participating_users)).quantize(cent, rounding=ROUND_HALF_UP)
    # Rounding can make the shares sum to slightly more or less than the
    # total; the difference (possibly negative or zero) goes to one user.
    remainder = total_list_cost - (equal_share_per_user * num_participating_users)

    user_balances = []
    for position, user_id in enumerate(sorted(participants)):
        user = participants[user_id]
        items_added = items_added_by_user.get(user_id, Decimal("0.00"))

        amount_due = equal_share_per_user
        if position == 0:
            # Lowest user id absorbs the rounding remainder.
            amount_due += remainder

        balance = items_added - amount_due
        user_balances.append(
            UserCostShare(
                user_id=user.id,
                user_identifier=user.name if user.name else user.email,
                items_added_value=items_added.quantize(cent, rounding=ROUND_HALF_UP),
                amount_due=amount_due.quantize(cent, rounding=ROUND_HALF_UP),
                balance=balance.quantize(cent, rounding=ROUND_HALF_UP)
            )
        )

    user_balances.sort(key=lambda x: x.user_identifier)
    return ListCostSummary(
        list_id=db_list.id,
        list_name=db_list.name,
        total_list_cost=total_list_cost.quantize(cent, rounding=ROUND_HALF_UP),
        num_participating_users=num_participating_users,
        equal_share_per_user=equal_share_per_user.quantize(cent, rounding=ROUND_HALF_UP),
        user_balances=user_balances
    )

@router.get(
    "/groups/{group_id}/balance-summary",
    response_model=GroupBalanceSummary,
    summary="Get Detailed Balance Summary for a Group",
    tags=["Costs", "Groups"],
    responses={
        status.HTTP_403_FORBIDDEN: {"description": "User does not have permission to access this group"},
        status.HTTP_404_NOT_FOUND: {"description": "Group not found"}
    }
)
async def get_group_balance_summary(
    group_id: int,
    db: AsyncSession = Depends(get_db),
    current_user: UserModel = Depends(get_current_user),
):
    """
    Compute per-member financial balances for a group.

    For every group member this totals what they paid for expenses, their
    share of expense splits, and the settlements they paid and received,
    then derives a net balance:

        net = (paid_for_expenses + settlements_paid)
              - (share_of_expenses + settlements_received)

    Paying a settlement raises the payer's balance and lowers the
    recipient's, so a fully settled debt nets to zero for both sides.
    Expenses, splits and settlements that reference users who are not
    current members of the group are ignored.

    Args:
        group_id: Id of the group to summarise.
        db: Async database session.
        current_user: Authenticated user; must be a member of the group.

    Returns:
        A ``GroupBalanceSummary`` with member balances (rounded half-up to
        cents, sorted by user identifier) and suggested settlements.

    Raises:
        GroupNotFoundError: If no group with ``group_id`` exists.
        HTTPException: 403 if the current user is not a member of the group.
    """
    logger.info("User %s requesting balance summary for group %s", current_user.email, group_id)

    # 1. Verify user is a member of the target group. `assoc.user` is read
    #    in step 3, so it is eager-loaded here; lazy loading is not
    #    available on an AsyncSession.
    group_check = await db.execute(
        select(GroupModel)
        .options(selectinload(GroupModel.member_associations).selectinload(UserGroupModel.user))
        .where(GroupModel.id == group_id)
    )
    db_group_for_check = group_check.scalars().first()

    if not db_group_for_check:
        raise GroupNotFoundError(group_id)

    user_is_member = any(assoc.user_id == current_user.id for assoc in db_group_for_check.member_associations)
    if not user_is_member:
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=f"User not a member of group {group_id}")

    # 2. Get all expenses and settlements for the group
    expenses_result = await db.execute(
        select(ExpenseModel)
        .where(ExpenseModel.group_id == group_id)
        .options(selectinload(ExpenseModel.splits).selectinload(ExpenseSplitModel.user))
    )
    expenses = expenses_result.scalars().all()

    settlements_result = await db.execute(
        select(SettlementModel)
        .where(SettlementModel.group_id == group_id)
        .options(
            selectinload(SettlementModel.paid_by_user),
            selectinload(SettlementModel.paid_to_user)
        )
    )
    settlements = settlements_result.scalars().all()

    # 3. Start one balance record per group member, keyed by user id.
    # NOTE(review): the `+=` accumulation below assumes UserBalanceDetail
    # defaults its monetary fields to Decimal zero — confirm in the schema.
    user_balances_data = {}
    for assoc in db_group_for_check.member_associations:
        member = assoc.user
        if member:
            user_balances_data[member.id] = UserBalanceDetail(
                user_id=member.id,
                user_identifier=member.name if member.name else member.email
            )

    # Process expenses: credit the payer with the full amount, charge each
    # split to the user it belongs to.
    for expense in expenses:
        payer = user_balances_data.get(expense.paid_by_user_id)
        if payer is not None:
            payer.total_paid_for_expenses += expense.total_amount

        for split in expense.splits:
            ower = user_balances_data.get(split.user_id)
            if ower is not None:
                ower.total_share_of_expenses += split.owed_amount

    # Process settlements
    for settlement in settlements:
        sender = user_balances_data.get(settlement.paid_by_user_id)
        if sender is not None:
            sender.total_settlements_paid += settlement.amount
        receiver = user_balances_data.get(settlement.paid_to_user_id)
        if receiver is not None:
            receiver.total_settlements_received += settlement.amount

    # Calculate net balances
    cent = Decimal("0.01")
    final_user_balances = []
    for data in user_balances_data.values():
        # Money handed over (expenses paid, settlements paid) counts in the
        # user's favour; money consumed or received counts against it.
        # Example: A pays 100 split 50/50 with B, then B pays A 50.
        #   A: (100 + 0) - (50 + 50) = 0    B: (0 + 50) - (50 + 0) = 0
        data.net_balance = (
            data.total_paid_for_expenses + data.total_settlements_paid
        ) - (data.total_share_of_expenses + data.total_settlements_received)

        data.total_paid_for_expenses = data.total_paid_for_expenses.quantize(cent, rounding=ROUND_HALF_UP)
        data.total_share_of_expenses = data.total_share_of_expenses.quantize(cent, rounding=ROUND_HALF_UP)
        data.total_settlements_paid = data.total_settlements_paid.quantize(cent, rounding=ROUND_HALF_UP)
        data.total_settlements_received = data.total_settlements_received.quantize(cent, rounding=ROUND_HALF_UP)
        data.net_balance = data.net_balance.quantize(cent, rounding=ROUND_HALF_UP)

        final_user_balances.append(data)

    # Sort by user identifier
    final_user_balances.sort(key=lambda x: x.user_identifier)

    # Calculate suggested settlements
    # NOTE(review): `calculate_suggested_settlements` is neither defined nor
    # imported in the part of this module visible during review — confirm it
    # is in scope, otherwise this line raises NameError.
    suggested_settlements = calculate_suggested_settlements(final_user_balances)

    return GroupBalanceSummary(
        group_id=db_group_for_check.id,
        group_name=db_group_for_check.name,
        user_balances=final_user_balances,
        suggested_settlements=suggested_settlements
    )