from typing import List, Optional
from decimal import Decimal
from datetime import datetime, timezone

from sqlalchemy import select, func, update, delete
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload, joinedload

from app.models import (
    SettlementActivity,
    ExpenseSplit,
    Expense,
    User,
    ExpenseSplitStatusEnum,
    ExpenseOverallStatusEnum,
)
# Placeholder for Pydantic schema - actual schema definition is a later step
# from app.schemas.settlement_activity import SettlementActivityCreate  # Assuming this path
from pydantic import BaseModel # Using pydantic BaseModel directly for the placeholder


class SettlementActivityCreatePlaceholder(BaseModel):
    # Stand-in input schema for create_settlement_activity until the real
    # schema module exists (see the commented-out import above).
    # Documented with comments rather than a class docstring on purpose:
    # Pydantic copies a model docstring into the generated schema description.

    # ID of the ExpenseSplit this payment is applied to.
    expense_split_id: int
    # ID of the User who made the payment.
    paid_by_user_id: int
    # Amount paid towards the split.
    amount_paid: Decimal
    # When the payment happened; create_settlement_activity substitutes the
    # current UTC time when this is left as None.
    paid_at: Optional[datetime] = None

    class Config:
        orm_mode = True # Pydantic V1 style orm_mode
        # from_attributes = True # Pydantic V2 style


async def update_expense_split_status(db: AsyncSession, expense_split_id: int) -> Optional[ExpenseSplit]:
    """
    Recomputes the status and paid_at of an ExpenseSplit from its settlement activities.

    The split is marked ``paid`` when the summed ``amount_paid`` of its
    activities reaches ``owed_amount``, ``partially_paid`` when the sum is
    positive but below it, and ``unpaid`` otherwise.

    This function does not modify the parent Expense. Callers that need the
    expense's overall status recomputed must call
    ``update_expense_overall_status`` afterwards.

    Args:
        db: Session used for the query and flush; changes are flushed, not committed.
        expense_split_id: Primary key of the split to recompute.

    Returns:
        The updated ExpenseSplit (with ``expense`` loaded), or None when no
        split with that ID exists.
    """
    # Fetch the split together with its activities and its parent expense.
    result = await db.execute(
        select(ExpenseSplit)
        .options(
            selectinload(ExpenseSplit.settlement_activities),
            joinedload(ExpenseSplit.expense),  # To get expense_id easily
        )
        .where(ExpenseSplit.id == expense_split_id)
    )
    expense_split = result.scalar_one_or_none()

    if not expense_split:
        # Or raise an exception, depending on desired error handling
        return None

    activities = expense_split.settlement_activities

    # sum() of an empty collection is the int 0; Decimal() normalises that
    # before rounding to two decimal places.
    total_paid = sum(activity.amount_paid for activity in activities)
    total_paid = Decimal(total_paid).quantize(Decimal("0.01"))

    if total_paid >= expense_split.owed_amount:
        expense_split.status = ExpenseSplitStatusEnum.paid
        # Use the most recent activity timestamp. default=None covers the case
        # where no activity carries a paid_at (or there are no activities at
        # all); a bare max() would raise ValueError on the empty sequence.
        latest_paid_at = max(
            (activity.paid_at for activity in activities if activity.paid_at),
            default=None,
        )
        expense_split.paid_at = latest_paid_at or datetime.now(timezone.utc)
    elif total_paid > 0:
        expense_split.status = ExpenseSplitStatusEnum.partially_paid
        expense_split.paid_at = None  # Clear paid_at if not fully paid
    else:  # Nothing paid (zero or negative total)
        expense_split.status = ExpenseSplitStatusEnum.unpaid
        expense_split.paid_at = None  # Clear paid_at

    await db.flush()
    # Reload the written columns and the related expense after the flush.
    await db.refresh(expense_split, attribute_names=['status', 'paid_at', 'expense'])

    return expense_split


async def update_expense_overall_status(db: AsyncSession, expense_id: int) -> Optional[Expense]:
    """
    Derives an Expense's overall_settlement_status from the statuses of its splits.

    All splits paid -> ``paid``; every split neither paid nor partially paid
    -> ``unpaid``; any other combination -> ``partially_paid``. An expense
    with no splits is set to ``unpaid``.

    Returns the updated Expense, or None when no expense has the given ID.
    Changes are flushed, not committed.
    """
    query = select(Expense).options(selectinload(Expense.splits)).where(Expense.id == expense_id)
    expense = (await db.execute(query)).scalar_one_or_none()

    if not expense:
        # Or raise an exception
        return None

    splits = expense.splits
    if not splits:
        # Defensive: an expense without splits is treated as unpaid.
        expense.overall_settlement_status = ExpenseOverallStatusEnum.unpaid
        await db.flush()
        await db.refresh(expense)
        return expense

    statuses = [split.status for split in splits]
    total = len(statuses)

    paid_count = sum(1 for status in statuses if status == ExpenseSplitStatusEnum.paid)
    # Anything that is neither paid nor partially paid counts as unpaid.
    unpaid_count = sum(
        1
        for status in statuses
        if not (
            status == ExpenseSplitStatusEnum.paid
            or status == ExpenseSplitStatusEnum.partially_paid
        )
    )

    if paid_count == total:
        overall = ExpenseOverallStatusEnum.paid
    elif unpaid_count == total:
        overall = ExpenseOverallStatusEnum.unpaid
    else:
        # Mixed statuses, or at least one partially paid split.
        overall = ExpenseOverallStatusEnum.partially_paid
    expense.overall_settlement_status = overall

    await db.flush()
    await db.refresh(expense, attribute_names=['overall_settlement_status'])
    return expense


async def create_settlement_activity(
    db: AsyncSession,
    settlement_activity_in: SettlementActivityCreatePlaceholder,
    current_user_id: int
) -> Optional[SettlementActivity]:
    """
    Records a payment against an expense split and recomputes dependent statuses.

    The referenced ExpenseSplit and paying User must both exist; if either is
    missing nothing is written and None is returned. Otherwise the activity is
    added and flushed, the split status is recomputed, and (when the split has
    an expense_id) the parent expense's overall status is recomputed too.

    ``current_user_id`` is stored as the creator of the record, which may
    differ from the payer. Nothing is committed here.
    """
    # The split must exist before anything is recorded against it.
    split_query = select(ExpenseSplit).where(
        ExpenseSplit.id == settlement_activity_in.expense_split_id
    )
    existing_split = (await db.execute(split_query)).scalar_one_or_none()
    if not existing_split:
        # Consider raising an HTTPException in an API layer
        return None

    # Likewise the paying user.
    payer_query = select(User).where(User.id == settlement_activity_in.paid_by_user_id)
    payer = (await db.execute(payer_query)).scalar_one_or_none()
    if not payer:
        return None

    # Fall back to the current UTC time when no payment time was supplied.
    effective_paid_at = settlement_activity_in.paid_at or datetime.now(timezone.utc)

    new_activity = SettlementActivity(
        expense_split_id=settlement_activity_in.expense_split_id,
        paid_by_user_id=settlement_activity_in.paid_by_user_id,
        amount_paid=settlement_activity_in.amount_paid,
        paid_at=effective_paid_at,
        created_by_user_id=current_user_id,  # The user recording the activity
    )
    db.add(new_activity)
    await db.flush()  # Assigns the primary key and makes the row visible to the queries below

    # Propagate the payment: split status first, then the parent expense.
    refreshed_split = await update_expense_split_status(
        db, expense_split_id=new_activity.expense_split_id
    )
    if refreshed_split and refreshed_split.expense_id:
        await update_expense_overall_status(db, expense_id=refreshed_split.expense_id)
    # Otherwise the split update returned None or the split has no expense_id;
    # the overall status update is skipped silently. Consider logging or
    # raising here.

    # Load the relationships callers are expected to read.
    await db.refresh(new_activity, attribute_names=['split', 'payer', 'creator'])

    return new_activity


async def get_settlement_activity_by_id(
    db: AsyncSession, settlement_activity_id: int
) -> Optional[SettlementActivity]:
    """
    Looks up one SettlementActivity by primary key with its relationships eager-loaded.

    The split (and that split's parent expense), the payer and the creator are
    loaded alongside the activity. Returns None when no row matches.
    """
    eager_loads = (
        # Split plus its parent expense
        selectinload(SettlementActivity.split).selectinload(ExpenseSplit.expense),
        # User who paid
        selectinload(SettlementActivity.payer),
        # User who created the record
        selectinload(SettlementActivity.creator),
    )
    query = (
        select(SettlementActivity)
        .options(*eager_loads)
        .where(SettlementActivity.id == settlement_activity_id)
    )
    rows = await db.execute(query)
    return rows.scalar_one_or_none()


async def get_settlement_activities_for_split(
    db: AsyncSession, expense_split_id: int, skip: int = 0, limit: int = 100
) -> List[SettlementActivity]:
    """
    Lists the SettlementActivity rows recorded against one expense split.

    Results are ordered newest first (by paid_at, then created_at), paginated
    with ``skip``/``limit``, and have payer and creator eager-loaded.
    """
    query = select(SettlementActivity).where(
        SettlementActivity.expense_split_id == expense_split_id
    )
    # Eager-load the paying user and the user who created the record.
    query = query.options(
        selectinload(SettlementActivity.payer),
        selectinload(SettlementActivity.creator),
    )
    # Newest payments first; creation time breaks ties.
    query = query.order_by(
        SettlementActivity.paid_at.desc(),
        SettlementActivity.created_at.desc(),
    )
    query = query.offset(skip).limit(limit)

    rows = await db.execute(query)
    return rows.scalars().all()

# Further CRUD operations like update/delete can be added later if needed.