# app/crud/expense.py
import logging # Add logging import
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy.orm import selectinload, joinedload
from decimal import Decimal, ROUND_HALF_UP, InvalidOperation as DecimalInvalidOperation
from typing import Callable, List as PyList, Optional, Sequence, Dict, defaultdict
from datetime import datetime, timezone # Added timezone

from app.models import (
    Expense as ExpenseModel,
    ExpenseSplit as ExpenseSplitModel,
    User as UserModel,
    List as ListModel,
    Group as GroupModel,
    UserGroup as UserGroupModel,
    SplitTypeEnum,
    Item as ItemModel
)
from app.schemas.expense import ExpenseCreate, ExpenseSplitCreate, ExpenseUpdate # Removed unused ExpenseUpdate
from app.core.exceptions import (
    # Using existing specific exceptions where possible
    ListNotFoundError,
    GroupNotFoundError,
    UserNotFoundError,
    InvalidOperationError # Import the new exception
)

# InvalidOperationError is imported from app.core.exceptions above. Should it ever be
# missing there, a local fallback like the one below could be used instead. If such an
# error reaches the API layer it should be (or be translated into) an HTTPException.
# class CrudInvalidOperationError(ValueError): # For internal CRUD validation logic
#     pass

# Module-scoped logger, registered under this module's dotted import path.
logger = logging.getLogger(name=__name__)

async def get_users_for_splitting(db: AsyncSession, expense_group_id: Optional[int], expense_list_id: Optional[int], expense_paid_by_user_id: int) -> PyList[UserModel]:
    """
    Resolve the users an expense is divided between.

    Resolution order:
      1. If ``expense_group_id`` is set: every member of that group.
      2. Otherwise, if ``expense_list_id`` is set: the members of the list's group,
         or the list's creator when the list has no group.
      3. If the steps above produced nobody: the payer alone.

    Users are de-duplicated by id, keeping first-seen order.

    Raises:
        GroupNotFoundError: ``expense_group_id`` is set but no such group exists.
        ListNotFoundError: the list context is used but no such list exists.
        UserNotFoundError: the payer fallback is needed but the payer does not exist.
        InvalidOperationError: no users could be determined (defensive; the payer
            fallback should normally prevent this).
    """
    # dicts preserve insertion order, so this doubles as an ordered id-keyed set.
    members_by_id: Dict[int, UserModel] = {}

    def remember(user: Optional[UserModel]) -> None:
        if user and user.id not in members_by_id:
            members_by_id[user.id] = user

    if expense_group_id:
        group_stmt = (
            select(GroupModel)
            .options(selectinload(GroupModel.member_associations).options(selectinload(UserGroupModel.user)))
            .where(GroupModel.id == expense_group_id)
        )
        group = (await db.execute(group_stmt)).scalars().first()
        if not group:
            raise GroupNotFoundError(expense_group_id)
        for membership in group.member_associations:
            remember(membership.user)

    elif expense_list_id:  # the list is only consulted when no group was given
        list_stmt = (
            select(ListModel)
            .options(
                selectinload(ListModel.group).options(selectinload(GroupModel.member_associations).options(selectinload(UserGroupModel.user))),
                selectinload(ListModel.creator)
            )
            .where(ListModel.id == expense_list_id)
        )
        db_list = (await db.execute(list_stmt)).scalars().first()
        if not db_list:
            raise ListNotFoundError(expense_list_id)

        if db_list.group:
            for membership in db_list.group.member_associations:
                remember(membership.user)
        elif db_list.creator:
            remember(db_list.creator)

    if not members_by_id:
        payer = await db.get(UserModel, expense_paid_by_user_id)
        if not payer:
            # Normally caught earlier, when paid_by_user_id is validated by the caller.
            raise UserNotFoundError(user_id=expense_paid_by_user_id)
        remember(payer)

    if not members_by_id:
        raise InvalidOperationError("Could not determine any users for splitting the expense.")

    return list(members_by_id.values())


async def create_expense(db: AsyncSession, expense_in: ExpenseCreate, current_user_id: int) -> ExpenseModel:
    """Create a new expense and its associated splits in a single transaction.

    Args:
        db: Database session.
        expense_in: Expense creation data.
        current_user_id: ID of the user creating the expense; stored as
            ``created_by_user_id``.

    Returns:
        The created expense, with its ``splits`` relationship refreshed.

    Raises:
        UserNotFoundError: If the payer or any split user doesn't exist.
        ListNotFoundError: If the specified list doesn't exist.
        GroupNotFoundError: If the specified group doesn't exist.
        InvalidOperationError: For validation failures, or when persisting the
            expense fails (the transaction is rolled back and the underlying
            error is chained as ``__cause__``).
    """
    def round_money(amount: Decimal) -> Decimal:
        # Normalise monetary values to two decimal places, rounding half-up.
        return amount.quantize(Decimal("0.01"), rounding=ROUND_HALF_UP)

    # 1. Context validation: an expense must belong to a list or a group.
    if not expense_in.list_id and not expense_in.group_id:
        raise InvalidOperationError("Expense must be associated with a list or a group.")

    # 2. Payer validation.
    payer = await db.get(UserModel, expense_in.paid_by_user_id)
    if not payer:
        raise UserNotFoundError(user_id=expense_in.paid_by_user_id)

    # 3. Resolve/validate list and group; a group-owned list dictates the group.
    final_group_id = await _resolve_expense_context(db, expense_in)

    # 4. Build the expense (not yet added to the session).
    db_expense = ExpenseModel(
        description=expense_in.description,
        total_amount=round_money(expense_in.total_amount),
        currency=expense_in.currency or "USD",
        expense_date=expense_in.expense_date or datetime.now(timezone.utc),
        split_type=expense_in.split_type,
        list_id=expense_in.list_id,
        group_id=final_group_id,
        item_id=expense_in.item_id,
        paid_by_user_id=expense_in.paid_by_user_id,
        created_by_user_id=current_user_id,  # Track who created this expense
    )

    # 5. Compute splits from the (rounded) total before touching the session.
    splits_to_create = await _generate_expense_splits(db, db_expense, expense_in, round_money)

    # 6. Persist expense and splits atomically.
    try:
        db.add(db_expense)
        await db.flush()  # Assigns db_expense.id without committing.

        for split in splits_to_create:
            split.expense_id = db_expense.id

        db.add_all(splits_to_create)
        await db.commit()
    except Exception as e:
        await db.rollback()
        logger.exception("Failed to save expense: %s", e)
        # Chain the original error so the DB failure stays visible as __cause__.
        raise InvalidOperationError(f"Failed to save expense: {str(e)}") from e

    # Refresh to get the splits relationship populated.
    await db.refresh(db_expense, attribute_names=["splits"])
    return db_expense


async def _resolve_expense_context(db: AsyncSession, expense_in: ExpenseCreate) -> Optional[int]:
    """Resolve and validate the expense's list/group context.

    - If ``list_id`` is given, the list must exist. If the list belongs to a
      group, that group is used; a conflicting ``group_id`` is rejected.
    - Any ``group_id`` that is *not* inherited from a group-owned list (no list
      given, or the list has no group) must refer to an existing group.

    Returns:
        The group id to store on the expense, or ``None`` when there is no group.

    Raises:
        ListNotFoundError: ``list_id`` refers to a missing list.
        GroupNotFoundError: a caller-supplied ``group_id`` refers to a missing group.
        InvalidOperationError: the list belongs to a different group than ``group_id``.
    """
    if expense_in.list_id:
        list_obj = await db.get(ListModel, expense_in.list_id)
        if not list_obj:
            raise ListNotFoundError(expense_in.list_id)

        if list_obj.group_id:
            if expense_in.group_id and list_obj.group_id != expense_in.group_id:
                raise InvalidOperationError(
                    f"List {expense_in.list_id} belongs to group {list_obj.group_id}, "
                    f"but expense was specified for group {expense_in.group_id}."
                )
            # The list's own group wins and is known to exist.
            return list_obj.group_id

    # Either no list was given, or the list has no group: a caller-supplied
    # group_id is used as-is, so it must be validated here (previously it was
    # only checked when no list_id was present).
    final_group_id = expense_in.group_id
    if final_group_id:
        group_obj = await db.get(GroupModel, final_group_id)
        if not group_obj:
            raise GroupNotFoundError(final_group_id)

    return final_group_id


async def _generate_expense_splits(
    db: AsyncSession,
    db_expense: ExpenseModel,
    expense_in: ExpenseCreate,
    round_money: Callable[[Decimal], Decimal]
) -> PyList[ExpenseSplitModel]:
    """Dispatch to the split builder matching ``expense_in.split_type``.

    Raises:
        InvalidOperationError: the split type has no builder, or the builder
            returned no splits.
    """
    # Ordered (type, builder) pairs compared with ``==`` in this order.
    builders = (
        (SplitTypeEnum.EQUAL, _create_equal_splits),
        (SplitTypeEnum.EXACT_AMOUNTS, _create_exact_amount_splits),
        (SplitTypeEnum.PERCENTAGE, _create_percentage_splits),
        (SplitTypeEnum.SHARES, _create_shares_splits),
        (SplitTypeEnum.ITEM_BASED, _create_item_based_splits),
    )

    for candidate_type, builder in builders:
        if expense_in.split_type == candidate_type:
            generated = await builder(db, db_expense, expense_in, round_money)
            break
    else:
        raise InvalidOperationError(f"Unsupported split type: {expense_in.split_type.value}")

    if not generated:
        raise InvalidOperationError("No expense splits were generated.")

    return generated


async def _create_equal_splits(
    db: AsyncSession,
    db_expense: ExpenseModel,
    expense_in: ExpenseCreate,
    round_money: Callable[[Decimal], Decimal]
) -> PyList[ExpenseSplitModel]:
    """Divide the expense total evenly among the users from ``get_users_for_splitting``.

    Each participant owes ``round_money(total / n)``. The difference left by
    rounding (positive or negative) is folded into the first participant's share,
    so the shares add up to the already-rounded total.
    """
    participants = await get_users_for_splitting(
        db, db_expense.group_id, expense_in.list_id, expense_in.paid_by_user_id
    )
    if not participants:
        raise InvalidOperationError("No users found for EQUAL split.")

    total = db_expense.total_amount
    head_count = len(participants)
    base_share = round_money(total / Decimal(head_count))
    leftover = total - (base_share * head_count)
    first_share = round_money(base_share + leftover) if leftover != Decimal('0') else base_share

    return [
        ExpenseSplitModel(
            user_id=member.id,
            owed_amount=first_share if position == 0 else base_share
        )
        for position, member in enumerate(participants)
    ]


async def _create_exact_amount_splits(
    db: AsyncSession,
    db_expense: ExpenseModel,
    expense_in: ExpenseCreate,
    round_money: Callable[[Decimal], Decimal]
) -> PyList[ExpenseSplitModel]:
    """Build one split per ``expense_in.splits_in`` entry using the exact amounts given.

    Each amount is rounded with ``round_money`` *before* the positivity check, so
    a value such as 0.004 (which would be stored as 0.00) is rejected instead of
    producing a zero split. A missing amount is rejected the same way.

    Raises:
        InvalidOperationError: ``splits_in`` is empty, an amount is missing or not
            positive after rounding, or the rounded amounts don't sum to the total.
        UserNotFoundError: propagated from ``_validate_users_in_splits``.
    """
    if not expense_in.splits_in:
        raise InvalidOperationError("Splits data is required for EXACT_AMOUNTS split type.")

    # Validate all users in splits exist.
    await _validate_users_in_splits(db, expense_in.splits_in)

    current_total = Decimal("0.00")
    splits = []

    for split_in in expense_in.splits_in:
        rounded_amount = (
            round_money(split_in.owed_amount) if split_in.owed_amount is not None else None
        )
        if rounded_amount is None or rounded_amount <= Decimal("0"):
            raise InvalidOperationError(f"Owed amount for user {split_in.user_id} must be positive.")

        current_total += rounded_amount
        splits.append(ExpenseSplitModel(
            user_id=split_in.user_id,
            owed_amount=rounded_amount
        ))

    if round_money(current_total) != db_expense.total_amount:
        raise InvalidOperationError(
            f"Sum of exact split amounts ({current_total}) != expense total ({db_expense.total_amount})."
        )

    return splits


async def _create_percentage_splits(
    db: AsyncSession,
    db_expense: ExpenseModel,
    expense_in: ExpenseCreate,
    round_money: Callable[[Decimal], Decimal]
) -> PyList[ExpenseSplitModel]:
    """Build splits where each user owes a percentage of the expense total.

    Every percentage must lie in (0, 100]; after rounding to two decimals the
    percentages must total exactly 100. Amounts are rounded per user, and the
    rounding difference against the total is absorbed by the last split.

    Raises:
        InvalidOperationError: missing split data, an out-of-range percentage, or
            percentages that don't total 100.
        UserNotFoundError: propagated from ``_validate_users_in_splits``.
    """
    entries = expense_in.splits_in
    if not entries:
        raise InvalidOperationError("Splits data is required for PERCENTAGE split type.")

    await _validate_users_in_splits(db, entries)

    expense_total = db_expense.total_amount
    pct_sum = Decimal("0.00")
    allocated = Decimal("0.00")
    result = []

    for entry in entries:
        pct = entry.share_percentage
        if not pct or not (Decimal("0") < pct <= Decimal("100")):
            raise InvalidOperationError(
                f"Invalid percentage {pct} for user {entry.user_id}."
            )

        pct_sum += pct
        share = round_money(expense_total * (pct / Decimal("100")))
        allocated += share
        result.append(ExpenseSplitModel(
            user_id=entry.user_id,
            owed_amount=share,
            share_percentage=pct
        ))

    if round_money(pct_sum) != Decimal("100.00"):
        raise InvalidOperationError(f"Sum of percentages ({pct_sum}%) is not 100%.")

    # Push any rounding drift onto the final split.
    if allocated != expense_total and result:
        last_split = result[-1]
        last_split.owed_amount = round_money(last_split.owed_amount + (expense_total - allocated))

    return result


async def _create_shares_splits(
    db: AsyncSession,
    db_expense: ExpenseModel,
    expense_in: ExpenseCreate,
    round_money: Callable[[Decimal], Decimal]
) -> PyList[ExpenseSplitModel]:
    """Build splits proportional to each user's share units.

    A user owes ``total * units / sum(units)`` rounded to two decimals; the
    rounding difference against the total is absorbed by the last split.

    Raises:
        InvalidOperationError: missing split data, zero total shares, or a
            non-positive/missing ``share_units`` entry.
        UserNotFoundError: propagated from ``_validate_users_in_splits``.
    """
    entries = expense_in.splits_in
    if not entries:
        raise InvalidOperationError("Splits data is required for SHARES split type.")

    await _validate_users_in_splits(db, entries)

    positive_units = [entry.share_units for entry in entries if entry.share_units and entry.share_units > 0]
    total_shares = sum(positive_units)
    if total_shares == 0:
        raise InvalidOperationError("Total shares cannot be zero for SHARES split.")

    expense_total = db_expense.total_amount
    denominator = Decimal(total_shares)
    allocated = Decimal("0.00")
    result = []

    for entry in entries:
        units = entry.share_units
        if not units or units <= 0:
            raise InvalidOperationError(f"Invalid share units for user {entry.user_id}.")

        share = round_money(expense_total * (Decimal(units) / denominator))
        allocated += share
        result.append(ExpenseSplitModel(
            user_id=entry.user_id,
            owed_amount=share,
            share_units=units
        ))

    # Push any rounding drift onto the final split.
    if allocated != expense_total and result:
        last_split = result[-1]
        last_split.owed_amount = round_money(last_split.owed_amount + (expense_total - allocated))

    return result


async def _create_item_based_splits(
    db: AsyncSession,
    db_expense: ExpenseModel,
    expense_in: ExpenseCreate,
    round_money: Callable[[Decimal], Decimal]
) -> PyList[ExpenseSplitModel]:
    """Build splits from the priced items of the expense's shopping list.

    Each item's price is charged to the user who added it. With
    ``expense_in.item_id`` only that item (which must belong to the list) is
    used; otherwise every item of the list with a positive price is used. The
    expense total must equal the sum of the item prices (rounded to two
    decimals). Per-user amounts are rounded individually and, as in the
    PERCENTAGE and SHARES splits, any rounding difference is absorbed by the
    last split so the splits always sum to the expense total.

    Any ``expense_in.splits_in`` payload is ignored (a warning is logged).

    Raises:
        InvalidOperationError: no ``list_id``; the item/list has no usable priced
            items; an item lacks its adder; or the total doesn't match the prices.
    """
    # defaultdict lives in collections; importing it via ``typing`` relies on an
    # implementation detail of the typing module.
    from collections import defaultdict

    if not expense_in.list_id:
        raise InvalidOperationError("ITEM_BASED expenses must be associated with a list_id.")

    if expense_in.splits_in:
        logger.warning("Ignoring provided 'splits_in' data for ITEM_BASED expense.")

    # Build query to fetch relevant items.
    items_query = select(ItemModel).where(ItemModel.list_id == expense_in.list_id)
    if expense_in.item_id:
        items_query = items_query.where(ItemModel.id == expense_in.item_id)
    else:
        items_query = items_query.where(ItemModel.price.isnot(None) & (ItemModel.price > Decimal("0")))

    # Load items with their adders.
    items_result = await db.execute(items_query.options(selectinload(ItemModel.added_by_user)))
    relevant_items = items_result.scalars().all()

    if not relevant_items:
        error_msg = (
            f"Specified item ID {expense_in.item_id} not found in list {expense_in.list_id}."
            if expense_in.item_id else
            f"List {expense_in.list_id} has no priced items to base the expense on."
        )
        raise InvalidOperationError(error_msg)

    # Aggregate owed amounts by the user who added each item.
    calculated_total = Decimal("0.00")
    user_owed_amounts = defaultdict(Decimal)
    processed_items = 0

    for item in relevant_items:
        if item.price is None or item.price <= Decimal("0"):
            if expense_in.item_id:
                raise InvalidOperationError(
                    f"Item ID {expense_in.item_id} must have a positive price for ITEM_BASED expense."
                )
            continue

        if not item.added_by_user:
            logger.error(f"Item ID {item.id} is missing added_by_user relationship.")
            raise InvalidOperationError(f"Data integrity issue: Item {item.id} is missing adder information.")

        calculated_total += item.price
        user_owed_amounts[item.added_by_user.id] += item.price
        processed_items += 1

    if processed_items == 0:
        raise InvalidOperationError(
            f"No items with positive prices found in list {expense_in.list_id} to create ITEM_BASED expense."
        )

    # Validate total matches calculated total.
    if round_money(calculated_total) != db_expense.total_amount:
        raise InvalidOperationError(
            f"Expense total amount ({db_expense.total_amount}) does not match the "
            f"calculated total from item prices ({calculated_total})."
        )

    splits = [
        ExpenseSplitModel(user_id=user_id, owed_amount=round_money(owed_amount))
        for user_id, owed_amount in user_owed_amounts.items()
    ]

    # Rounding each user's sum separately can drift from the rounded total when
    # prices carry more than two decimals; reconcile on the last split.
    allocated = sum((split.owed_amount for split in splits), Decimal("0.00"))
    if allocated != db_expense.total_amount and splits:
        diff = db_expense.total_amount - allocated
        splits[-1].owed_amount = round_money(splits[-1].owed_amount + diff)

    return splits


async def _validate_users_in_splits(db: AsyncSession, splits_in: PyList[ExpenseSplitCreate]) -> None:
    """Ensure each split names a distinct, existing user.

    Previously duplicates were detected only indirectly (set size vs. list size)
    and reported as ``UserNotFoundError`` with an empty "missing" list; they are
    now rejected explicitly before the database is queried.

    Raises:
        InvalidOperationError: the same user id appears in more than one split.
        UserNotFoundError: one or more user ids do not exist (listed in sorted order).
    """
    from collections import Counter

    id_counts = Counter(s.user_id for s in splits_in)
    duplicate_ids = sorted(uid for uid, count in id_counts.items() if count > 1)
    if duplicate_ids:
        raise InvalidOperationError(
            f"Each user may appear only once in split data; duplicated user IDs: {duplicate_ids}"
        )

    unique_ids = list(id_counts)
    user_results = await db.execute(select(UserModel.id).where(UserModel.id.in_(unique_ids)))
    found_user_ids = {row[0] for row in user_results}

    missing_user_ids = set(unique_ids) - found_user_ids
    if missing_user_ids:
        raise UserNotFoundError(identifier=f"users in split data: {sorted(missing_user_ids)}")


async def get_expense_by_id(db: AsyncSession, expense_id: int) -> Optional[ExpenseModel]:
    """Fetch one expense by id, or ``None`` if no such expense exists.

    Splits (with each split's user), the payer, and the linked list, group and
    item are eager-loaded with ``selectinload``.
    """
    eager_loads = (
        selectinload(ExpenseModel.splits).options(selectinload(ExpenseSplitModel.user)),
        selectinload(ExpenseModel.paid_by_user),
        selectinload(ExpenseModel.list),
        selectinload(ExpenseModel.group),
        selectinload(ExpenseModel.item),
    )
    stmt = select(ExpenseModel).where(ExpenseModel.id == expense_id).options(*eager_loads)
    rows = await db.execute(stmt)
    return rows.scalars().first()

async def get_expenses_for_list(db: AsyncSession, list_id: int, skip: int = 0, limit: int = 100) -> Sequence[ExpenseModel]:
    """Return one page of a list's expenses, newest first.

    Ordered by ``expense_date`` then ``created_at`` (both descending); ``skip``
    and ``limit`` become OFFSET/LIMIT. Splits and each split's user are eager-loaded.
    """
    newest_first = (ExpenseModel.expense_date.desc(), ExpenseModel.created_at.desc())
    stmt = (
        select(ExpenseModel)
        .options(selectinload(ExpenseModel.splits).options(selectinload(ExpenseSplitModel.user)))
        .where(ExpenseModel.list_id == list_id)
        .order_by(*newest_first)
        .limit(limit)
        .offset(skip)
    )
    rows = await db.execute(stmt)
    return rows.scalars().all()

async def get_expenses_for_group(db: AsyncSession, group_id: int, skip: int = 0, limit: int = 100) -> Sequence[ExpenseModel]:
    """Return one page of a group's expenses, newest first.

    Ordered by ``expense_date`` then ``created_at`` (both descending); ``skip``
    and ``limit`` become OFFSET/LIMIT. Splits and each split's user are eager-loaded.
    """
    newest_first = (ExpenseModel.expense_date.desc(), ExpenseModel.created_at.desc())
    stmt = (
        select(ExpenseModel)
        .options(selectinload(ExpenseModel.splits).options(selectinload(ExpenseSplitModel.user)))
        .where(ExpenseModel.group_id == group_id)
        .order_by(*newest_first)
        .limit(limit)
        .offset(skip)
    )
    rows = await db.execute(stmt)
    return rows.scalars().all()

async def update_expense(db: AsyncSession, expense_db: ExpenseModel, expense_in: ExpenseUpdate) -> ExpenseModel:
    """
    Update an expense's editable metadata, guarded by optimistic locking.

    Only ``description``, ``currency`` and ``expense_date`` may change, to avoid
    split complexities. Every field set on ``expense_in`` (other than
    ``version``) is checked *before* ``expense_db`` is touched, so a rejected
    request never leaves the session-tracked object half-updated. On success
    the version is incremented, ``updated_at`` is set to the current UTC time,
    and the change is committed and refreshed.

    Args:
        db: Database session.
        expense_db: The expense to update.
        expense_in: Update payload; its ``version`` must equal ``expense_db.version``.

    Returns:
        The refreshed expense.

    Raises:
        InvalidOperationError: version mismatch, a non-updatable field was
            supplied, or the commit failed (session rolled back; cause chained).
    """
    if expense_db.version != expense_in.version:
        # The API layer is expected to map this to 409 Conflict.
        raise InvalidOperationError(
            f"Expense '{expense_db.description}' (ID: {expense_db.id}) has been modified. "
            f"Your version is {expense_in.version}, current version is {expense_db.version}. Please refresh."
        )

    update_data = expense_in.model_dump(exclude_unset=True, exclude={"version"})

    # Fields that are safe to update without affecting splits or core logic.
    allowed_to_update = {"description", "currency", "expense_date"}

    # Validate the whole payload first: raising mid-assignment would leave
    # earlier allowed fields already set on the ORM object.
    for field in update_data:
        if field not in allowed_to_update:
            raise InvalidOperationError(
                f"Field '{field}' cannot be updated. Only {', '.join(sorted(allowed_to_update))} are allowed."
            )

    for field, value in update_data.items():
        setattr(expense_db, field, value)

    # A payload carrying only a matching version still bumps the version.
    expense_db.version += 1
    expense_db.updated_at = datetime.now(timezone.utc)  # Manually update timestamp

    try:
        await db.commit()
        await db.refresh(expense_db)
    except Exception as e:
        await db.rollback()
        raise InvalidOperationError(f"Failed to update expense: {str(e)}") from e

    return expense_db

async def delete_expense(db: AsyncSession, expense_db: ExpenseModel, expected_version: Optional[int] = None) -> None:
    """
    Delete an expense and commit, optionally enforcing optimistic locking.

    Associated ExpenseSplits are expected to be removed by a cascading foreign
    key. NOTE(review): that constraint is defined outside this module; confirm.

    Args:
        db: Database session.
        expense_db: The expense to delete.
        expected_version: When given, must equal ``expense_db.version``.

    Raises:
        InvalidOperationError: version mismatch, or the delete/commit failed
            (session rolled back; original error chained as ``__cause__``).
    """
    if expected_version is not None and expense_db.version != expected_version:
        # The API layer is expected to map this to 409 Conflict.
        raise InvalidOperationError(
            f"Expense '{expense_db.description}' (ID: {expense_db.id}) cannot be deleted. "
            f"Your expected version {expected_version} does not match current version {expense_db.version}. Please refresh."
        )

    # Keep delete() inside the try so a failure there is also rolled back.
    try:
        await db.delete(expense_db)
        await db.commit()
    except Exception as e:
        await db.rollback()
        raise InvalidOperationError(f"Failed to delete expense: {str(e)}") from e

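# Note: InvalidOperationError (like the other CRUD errors raised here) is imported from
# app.core.exceptions; its base class is not visible in this module. If it is not
# already an HTTPException, API endpoints should translate these errors into
# appropriate HTTP responses (e.g. 409 Conflict for the version-mismatch cases).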