# app/crud/expense.py
import logging # Add logging import
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy.orm import selectinload, joinedload
from sqlalchemy.exc import SQLAlchemyError, IntegrityError, OperationalError # Added import
from decimal import Decimal, ROUND_HALF_UP, InvalidOperation as DecimalInvalidOperation
from typing import Callable, List as PyList, Optional, Sequence, Dict, defaultdict, Any
from datetime import datetime, timezone # Added timezone

from app.models import (
    Expense as ExpenseModel,
    ExpenseSplit as ExpenseSplitModel,
    User as UserModel,
    List as ListModel,
    Group as GroupModel,
    UserGroup as UserGroupModel,
    SplitTypeEnum,
    Item as ItemModel,
    ExpenseOverallStatusEnum, # Added
    ExpenseSplitStatusEnum,   # Added
)
from app.schemas.expense import ExpenseCreate, ExpenseSplitCreate, ExpenseUpdate # Removed unused ExpenseUpdate
from app.core.exceptions import (
    # Using existing specific exceptions where possible
    ListNotFoundError,
    GroupNotFoundError,
    UserNotFoundError,
    InvalidOperationError, # Import the new exception
    DatabaseConnectionError, # Added
    DatabaseIntegrityError,  # Added
    DatabaseQueryError,      # Added
    DatabaseTransactionError,# Added
    ExpenseOperationError    # Added specific exception
)
from app.models import RecurrencePattern

# NOTE(review): InvalidOperationError is imported from app.core.exceptions above,
# so no local placeholder exception class is needed in this module. Mapping these
# exceptions to HTTP responses is left to the API layer.

# Module-level logger named after this module's dotted import path.
logger = logging.getLogger(__name__)

def _round_money(amount: Decimal) -> Decimal:
    """Return *amount* quantized to whole cents (two decimal places).

    Ties are rounded away from zero (``ROUND_HALF_UP``), e.g. ``1.005 -> 1.01``.
    """
    one_cent = Decimal("0.01")
    return amount.quantize(one_cent, rounding=ROUND_HALF_UP)

async def get_users_for_splitting(db: AsyncSession, expense_group_id: Optional[int], expense_list_id: Optional[int], expense_paid_by_user_id: int) -> PyList[UserModel]:
    """
    Work out which users an expense is divided between.

    Resolution order:
      1. ``expense_group_id`` set: every member of that group.
      2. Otherwise ``expense_list_id`` set: the members of the list's group, or the
         list's creator when the list has no group.
      3. If neither produced anyone: the payer alone.

    Users are de-duplicated by id, keeping first-seen order.

    Raises:
        GroupNotFoundError: If ``expense_group_id`` does not exist.
        ListNotFoundError: If ``expense_list_id`` does not exist.
        UserNotFoundError: If the payer fallback is needed and the payer does not exist.
        InvalidOperationError: Defensive; raised only if no user could be determined.
    """
    # Insertion-ordered map: keeps the first occurrence of each user id.
    unique_users: Dict[int, UserModel] = {}

    def remember(user: Optional[UserModel]) -> None:
        if user and user.id not in unique_users:
            unique_users[user.id] = user

    if expense_group_id:
        group_stmt = (
            select(GroupModel)
            .options(
                selectinload(GroupModel.member_associations).options(selectinload(UserGroupModel.user))
            )
            .where(GroupModel.id == expense_group_id)
        )
        group = (await db.execute(group_stmt)).scalars().first()
        if not group:
            raise GroupNotFoundError(expense_group_id)
        for membership in group.member_associations:
            remember(membership.user)

    elif expense_list_id:  # list context only applies when no group_id was given
        list_stmt = (
            select(ListModel)
            .options(
                selectinload(ListModel.group).options(
                    selectinload(GroupModel.member_associations).options(selectinload(UserGroupModel.user))
                ),
                selectinload(ListModel.creator),
            )
            .where(ListModel.id == expense_list_id)
        )
        db_list = (await db.execute(list_stmt)).scalars().first()
        if not db_list:
            raise ListNotFoundError(expense_list_id)

        if db_list.group:
            for membership in db_list.group.member_associations:
                remember(membership.user)
        elif db_list.creator:
            remember(db_list.creator)

    if not unique_users:
        # No group/list context yielded anyone: fall back to the payer.
        payer = await db.get(UserModel, expense_paid_by_user_id)
        if not payer:
            raise UserNotFoundError(user_id=expense_paid_by_user_id)
        remember(payer)

    if not unique_users:
        # Defensive: the payer fallback above should always add someone.
        raise InvalidOperationError("Could not determine any users for splitting the expense.")

    return list(unique_users.values())


async def create_expense(db: AsyncSession, expense_in: ExpenseCreate, current_user_id: int) -> ExpenseModel:
    """Create a new expense and its associated splits atomically.

    Runs inside a SAVEPOINT (``begin_nested``) when the session already has a
    transaction in progress, otherwise inside a new transaction committed on
    successful exit. Any exception rolls the work back.

    Args:
        db: Database session.
        expense_in: Expense creation data. It is not mutated: when ``list_id`` has
            to be derived from ``item_id``, an updated copy is used internally.
        current_user_id: ID of the user creating the expense. It is stored as
            ``created_by_user_id`` and forwarded to the split generators.

    Returns:
        The created expense, re-fetched with payer, creator, list, group, item and
        splits (each with its user) eagerly loaded.

    Raises:
        UserNotFoundError: If the payer or a user referenced by the splits doesn't exist.
        ListNotFoundError: If the specified (or item-derived) list doesn't exist.
        GroupNotFoundError: If the specified group doesn't exist.
        InvalidOperationError: For validation failures. These include: no
            list/group/item context, an unknown item, an item that belongs to a
            different list, and invalid split data.
        ExpenseOperationError: If the new expense cannot be re-loaded after creation.
        DatabaseIntegrityError: On a database ``IntegrityError``.
        DatabaseConnectionError: On a database ``OperationalError``.
        DatabaseTransactionError: On any other ``SQLAlchemyError``.
    """
    try:
        async with db.begin_nested() if db.in_transaction() else db.begin():
            # 1. Validate payer
            payer = await db.get(UserModel, expense_in.paid_by_user_id)
            if not payer:
                raise UserNotFoundError(user_id=expense_in.paid_by_user_id, identifier="Payer")

            # 2. Context resolution and validation
            if not expense_in.list_id and not expense_in.group_id and not expense_in.item_id:
                raise InvalidOperationError("Expense must be associated with a list, a group, or an item.")

            final_group_id = await _resolve_expense_context(db, expense_in)

            if expense_in.item_id:
                db_item_instance = await db.get(ItemModel, expense_in.item_id)
                if not db_item_instance:
                    raise InvalidOperationError(f"Item with ID {expense_in.item_id} not found.")
                if db_item_instance.list_id:
                    if not expense_in.list_id:
                        # Inherit the item's list. Work on a copy so the caller's request
                        # object is not mutated, then re-resolve the group from that list.
                        expense_in = expense_in.model_copy(update={"list_id": db_item_instance.list_id})
                        final_group_id = await _resolve_expense_context(db, expense_in)
                    elif db_item_instance.list_id != expense_in.list_id:
                        raise InvalidOperationError(
                            f"Item {expense_in.item_id} belongs to list {db_item_instance.list_id}, "
                            f"but expense was specified for list {expense_in.list_id}."
                        )

            # Resolve the date once, so a recurring expense's next_occurrence is not None
            # just because the client omitted expense_date.
            expense_date = expense_in.expense_date or datetime.now(timezone.utc)

            # 3. Create the recurrence pattern for recurring expenses
            recurrence_pattern = None
            if expense_in.is_recurring and expense_in.recurrence_pattern:
                pattern_in = expense_in.recurrence_pattern
                now = datetime.now(timezone.utc)
                recurrence_pattern = RecurrencePattern(
                    type=pattern_in.type,
                    interval=pattern_in.interval,
                    days_of_week=pattern_in.days_of_week,
                    end_date=pattern_in.end_date,
                    max_occurrences=pattern_in.max_occurrences,
                    created_at=now,
                    updated_at=now,
                )
                db.add(recurrence_pattern)
                await db.flush()

            # 4. Create the ExpenseModel instance
            db_expense = ExpenseModel(
                description=expense_in.description,
                total_amount=_round_money(expense_in.total_amount),
                currency=expense_in.currency or "USD",
                expense_date=expense_date,
                split_type=expense_in.split_type,
                list_id=expense_in.list_id,
                group_id=final_group_id,  # resolved/validated group
                item_id=expense_in.item_id,
                paid_by_user_id=expense_in.paid_by_user_id,
                created_by_user_id=current_user_id,
                overall_settlement_status=ExpenseOverallStatusEnum.unpaid,
                is_recurring=expense_in.is_recurring,
                recurrence_pattern=recurrence_pattern,
                next_occurrence=expense_date if expense_in.is_recurring else None,
            )
            db.add(db_expense)
            await db.flush()  # assigns db_expense.id

            # 5. Generate splits. current_user_id is forwarded for split types that need it.
            splits_to_create = await _generate_expense_splits(
                db=db,
                expense_model=db_expense,
                expense_in=expense_in,
                current_user_id=current_user_id,
            )
            for split_model in splits_to_create:
                split_model.expense_id = db_expense.id  # FK is only known after the flush above
            db.add_all(splits_to_create)
            await db.flush()

            # 6. Re-fetch with the relationships needed for the response
            stmt = (
                select(ExpenseModel)
                .where(ExpenseModel.id == db_expense.id)
                .options(
                    selectinload(ExpenseModel.paid_by_user),
                    selectinload(ExpenseModel.created_by_user),
                    selectinload(ExpenseModel.list),
                    selectinload(ExpenseModel.group),
                    selectinload(ExpenseModel.item),
                    selectinload(ExpenseModel.splits).selectinload(ExpenseSplitModel.user),
                )
            )
            result = await db.execute(stmt)
            loaded_expense = result.scalar_one_or_none()

            if loaded_expense is None:
                # Raising inside the context manager rolls the transaction back.
                raise ExpenseOperationError("Failed to load expense after creation.")

            return loaded_expense

    except (UserNotFoundError, ListNotFoundError, GroupNotFoundError, InvalidOperationError):
        # Business-rule violations: the context manager has already rolled back.
        raise
    except IntegrityError as e:
        logger.error(f"Database integrity error during expense creation: {str(e)}", exc_info=True)
        raise DatabaseIntegrityError(f"Failed to save expense due to database integrity issue: {str(e)}") from e
    except OperationalError as e:
        logger.error(f"Database connection error during expense creation: {str(e)}", exc_info=True)
        raise DatabaseConnectionError(f"Database connection error during expense creation: {str(e)}") from e
    except SQLAlchemyError as e:
        logger.error(f"Unexpected SQLAlchemy error during expense creation: {str(e)}", exc_info=True)
        raise DatabaseTransactionError(f"Failed to save expense due to a database transaction error: {str(e)}") from e


async def _resolve_expense_context(db: AsyncSession, expense_in: ExpenseCreate) -> Optional[int]:
    """Validate the expense's list/group context and return the group_id to store.

    - If ``list_id`` is given, the list must exist. When the list belongs to a group,
      that group wins. A conflicting explicit ``group_id`` is rejected.
    - Any ``group_id`` that was not inherited from the list must exist. This covers
      both the no-list case and the case of a list without a group.

    Returns:
        The resolved group id, or ``None`` if the expense has no group context.

    Raises:
        ListNotFoundError: If ``list_id`` does not exist.
        GroupNotFoundError: If a caller-supplied ``group_id`` does not exist.
        InvalidOperationError: If the list belongs to a different group than ``group_id``.
    """
    final_group_id = expense_in.group_id
    group_inherited_from_list = False

    if expense_in.list_id:
        list_obj = await db.get(ListModel, expense_in.list_id)
        if not list_obj:
            raise ListNotFoundError(expense_in.list_id)

        if list_obj.group_id:
            if expense_in.group_id and list_obj.group_id != expense_in.group_id:
                raise InvalidOperationError(
                    f"List {expense_in.list_id} belongs to group {list_obj.group_id}, "
                    f"but expense was specified for group {expense_in.group_id}."
                )
            final_group_id = list_obj.group_id  # the list's group takes priority
            group_inherited_from_list = True

    # A group id taken from the request (not from an existing list) must be validated,
    # including when list_id points at a list that has no group.
    if final_group_id and not group_inherited_from_list:
        group_obj = await db.get(GroupModel, final_group_id)
        if not group_obj:
            raise GroupNotFoundError(final_group_id)

    return final_group_id


async def _generate_expense_splits(
    db: AsyncSession,
    expense_model: ExpenseModel,
    expense_in: ExpenseCreate,
    **kwargs: Any
) -> PyList[ExpenseSplitModel]:
    """Build the split rows for ``expense_model`` according to ``expense_in.split_type``.

    Dispatches to the matching ``_create_*_splits`` helper. Extra keyword arguments
    (e.g. ``current_user_id``) are forwarded to the helper as ordinary keyword
    arguments. The returned splits are not added to the session here and carry no
    ``expense_id``; the caller is responsible for both.

    Raises:
        InvalidOperationError: If the split type is unsupported or no splits were
            produced. The helpers may also raise their own validation errors.
    """
    split_type = expense_in.split_type
    if split_type == SplitTypeEnum.EQUAL:
        create_splits = _create_equal_splits
    elif split_type == SplitTypeEnum.EXACT_AMOUNTS:
        create_splits = _create_exact_amount_splits
    elif split_type == SplitTypeEnum.PERCENTAGE:
        create_splits = _create_percentage_splits
    elif split_type == SplitTypeEnum.SHARES:
        create_splits = _create_shares_splits
    elif split_type == SplitTypeEnum.ITEM_BASED:
        create_splits = _create_item_based_splits
    else:
        # getattr: report a raw (non-enum) value instead of crashing on `.value`.
        split_type_label = getattr(split_type, "value", split_type)
        raise InvalidOperationError(f"Unsupported split type: {split_type_label}")

    # Spread **kwargs so helpers receive e.g. current_user_id as a real keyword
    # argument rather than nested under a "kwargs" key.
    splits_to_create = await create_splits(
        db=db,
        expense_model=expense_model,
        expense_in=expense_in,
        round_money_func=_round_money,
        **kwargs,
    )

    if not splits_to_create:
        raise InvalidOperationError("No expense splits were generated.")

    return splits_to_create


async def _create_equal_splits(db: AsyncSession, expense_model: ExpenseModel, expense_in: ExpenseCreate, round_money_func: Callable[[Decimal], Decimal], **kwargs: Any) -> PyList[ExpenseSplitModel]:
    """Split ``expense_model.total_amount`` equally among the users from ``get_users_for_splitting``.

    Each user owes the total divided by the number of users, rounded with
    ``round_money_func``. The rounding remainder is spread one cent at a time over
    the first users instead of being charged entirely to one person. When the total
    has at most two decimal places, that remainder is a whole number of cents
    smaller than the user count. In that case no two shares differ by more than one
    cent and the shares sum exactly to the total.

    Raises:
        InvalidOperationError: If no users are found to split with.
    """
    users_for_splitting = await get_users_for_splitting(
        db, expense_model.group_id, expense_model.list_id, expense_model.paid_by_user_id
    )
    if not users_for_splitting:
        raise InvalidOperationError("No users found for EQUAL split.")

    total = expense_model.total_amount
    num_users = len(users_for_splitting)
    one_cent = Decimal("0.01")
    base_share = round_money_func(total / Decimal(num_users))

    # Positive when the rounded shares undershoot the total, negative when they overshoot.
    remainder = total - (base_share * num_users)
    adjusted_users = int((remainder / one_cent).to_integral_value())
    step = one_cent if adjusted_users >= 0 else -one_cent

    splits = []
    for index, user in enumerate(users_for_splitting):
        split_amount = base_share
        if index < abs(adjusted_users):
            split_amount = round_money_func(base_share + step)

        splits.append(ExpenseSplitModel(
            user_id=user.id,
            owed_amount=split_amount,
            status=ExpenseSplitStatusEnum.unpaid,  # explicit default status
        ))

    return splits


async def _create_exact_amount_splits(db: AsyncSession, expense_model: ExpenseModel, expense_in: ExpenseCreate, round_money_func: Callable[[Decimal], Decimal], **kwargs: Any) -> PyList[ExpenseSplitModel]:
    """Create splits whose owed amounts are given explicitly in ``expense_in.splits_in``.

    Each amount is rounded with ``round_money_func`` and must still be positive
    after rounding. The rounded amounts must sum to ``expense_model.total_amount``.

    Raises:
        InvalidOperationError: If split data is missing, an amount is missing or not
            positive after rounding, or the amounts don't add up to the total.
        UserNotFoundError: Via ``_validate_users_in_splits``, if a referenced user
            doesn't exist.
    """
    if not expense_in.splits_in:
        raise InvalidOperationError("Splits data is required for EXACT_AMOUNTS split type.")

    await _validate_users_in_splits(db, expense_in.splits_in)

    current_total = Decimal("0.00")
    splits = []

    for split_in in expense_in.splits_in:
        # Validate after rounding. Otherwise a positive amount below half a cent would
        # be stored as a 0.00 split while still passing the sum check below.
        rounded_amount = (
            round_money_func(split_in.owed_amount) if split_in.owed_amount is not None else None
        )
        if rounded_amount is None or rounded_amount <= Decimal('0'):
            raise InvalidOperationError(f"Owed amount for user {split_in.user_id} must be positive.")

        current_total += rounded_amount
        splits.append(ExpenseSplitModel(
            user_id=split_in.user_id,
            owed_amount=rounded_amount,
            status=ExpenseSplitStatusEnum.unpaid,  # explicit default status
        ))

    if round_money_func(current_total) != expense_model.total_amount:
        raise InvalidOperationError(
            f"Sum of exact split amounts ({current_total}) != expense total ({expense_model.total_amount})."
        )

    return splits


async def _create_percentage_splits(db: AsyncSession, expense_model: ExpenseModel, expense_in: ExpenseCreate, round_money_func: Callable[[Decimal], Decimal], **kwargs: Any) -> PyList[ExpenseSplitModel]:
    """Split the total according to each participant's ``share_percentage``.

    Every percentage must lie in (0, 100], and the percentages (rounded to two
    places) must add up to 100. Each owed amount is ``total * pct / 100`` rounded
    with ``round_money_func``. Any rounding drift is folded into the last split so
    the amounts add up to ``expense_model.total_amount``.

    Raises:
        InvalidOperationError: If split data is missing, a percentage is invalid, or
            the percentages do not sum to 100.
    """
    entries = expense_in.splits_in
    if not entries:
        raise InvalidOperationError("Splits data is required for PERCENTAGE split type.")

    await _validate_users_in_splits(db, entries)

    hundred = Decimal("100")
    expense_total = expense_model.total_amount
    percent_sum = Decimal("0.00")
    allocated = Decimal("0.00")
    built: PyList[ExpenseSplitModel] = []

    for entry in entries:
        pct = entry.share_percentage
        in_range = bool(pct) and Decimal("0") < pct <= hundred
        if not in_range:
            raise InvalidOperationError(
                f"Invalid percentage {pct} for user {entry.user_id}."
            )

        percent_sum += pct
        portion = round_money_func(expense_total * (pct / hundred))
        allocated += portion

        built.append(ExpenseSplitModel(
            user_id=entry.user_id,
            owed_amount=portion,
            share_percentage=pct,
            status=ExpenseSplitStatusEnum.unpaid,  # explicit default status
        ))

    if round_money_func(percent_sum) != Decimal("100.00"):
        raise InvalidOperationError(f"Sum of percentages ({percent_sum}%) is not 100%.")

    # Fold any rounding drift into the final participant's share.
    if built and allocated != expense_total:
        tail = built[-1]
        tail.owed_amount = round_money_func(tail.owed_amount + (expense_total - allocated))

    return built


async def _create_shares_splits(db: AsyncSession, expense_model: ExpenseModel, expense_in: ExpenseCreate, round_money_func: Callable[[Decimal], Decimal], **kwargs: Any) -> PyList[ExpenseSplitModel]:
    """Split the total in proportion to each participant's ``share_units``.

    Every entry must carry positive ``share_units``. Each owed amount is
    ``total * units / sum(units)`` rounded with ``round_money_func``. Any rounding
    drift is folded into the last split so the amounts add up to
    ``expense_model.total_amount``.

    Raises:
        InvalidOperationError: If split data is missing, the positive share units sum
            to zero, or an entry has missing/non-positive share units.
    """
    entries = expense_in.splits_in
    if not entries:
        raise InvalidOperationError("Splits data is required for SHARES split type.")

    await _validate_users_in_splits(db, entries)

    # Only positive units count towards the denominator; invalid entries are rejected below.
    unit_total = 0
    for entry in entries:
        if entry.share_units and entry.share_units > 0:
            unit_total += entry.share_units
    if unit_total == 0:
        raise InvalidOperationError("Total shares cannot be zero for SHARES split.")

    expense_total = expense_model.total_amount
    denominator = Decimal(unit_total)
    allocated = Decimal("0.00")
    built: PyList[ExpenseSplitModel] = []

    for entry in entries:
        units = entry.share_units
        if not (units and units > 0):
            raise InvalidOperationError(f"Invalid share units for user {entry.user_id}.")

        portion = round_money_func(expense_total * (Decimal(units) / denominator))
        allocated += portion

        built.append(ExpenseSplitModel(
            user_id=entry.user_id,
            owed_amount=portion,
            share_units=units,
            status=ExpenseSplitStatusEnum.unpaid,  # explicit default status
        ))

    # Fold any rounding drift into the final participant's share.
    if built and allocated != expense_total:
        tail = built[-1]
        tail.owed_amount = round_money_func(tail.owed_amount + (expense_total - allocated))

    return built


async def _create_item_based_splits(db: AsyncSession, expense_model: ExpenseModel, expense_in: ExpenseCreate, round_money_func: Callable[[Decimal], Decimal], **kwargs: Any) -> PyList[ExpenseSplitModel]:
    """Create one split per item adder, owing the summed price of the items they added.

    Considers either the single ``expense_model.item_id`` (which must be in the
    expense's list and have a positive price) or every positively priced item in
    the list. The summed prices (rounded) must equal the expense total.
    ``expense_in.splits_in`` is ignored, with a warning.

    Raises:
        InvalidOperationError: If the expense has no list, no suitable items are
            found, an item lacks adder information, or the totals disagree.
    """
    if not expense_model.list_id:
        raise InvalidOperationError("ITEM_BASED expenses must be associated with a list_id.")

    if expense_in.splits_in:
        logger.warning("Ignoring provided 'splits_in' data for ITEM_BASED expense.")

    # Build query to fetch relevant items
    items_query = select(ItemModel).where(ItemModel.list_id == expense_model.list_id)
    if expense_model.item_id:
        items_query = items_query.where(ItemModel.id == expense_model.item_id)
    else:
        items_query = items_query.where(ItemModel.price.isnot(None) & (ItemModel.price > Decimal("0")))

    # Eager-load the adder of each item
    items_result = await db.execute(items_query.options(selectinload(ItemModel.added_by_user)))
    relevant_items = items_result.scalars().all()

    if not relevant_items:
        error_msg = (
            f"Specified item ID {expense_model.item_id} not found in list {expense_model.list_id}."
            if expense_model.item_id else
            f"List {expense_model.list_id} has no priced items to base the expense on."
        )
        raise InvalidOperationError(error_msg)

    calculated_total = Decimal("0.00")
    # Plain dict instead of `defaultdict`: the module's `defaultdict` is imported from
    # `typing`, which does not document such an export.
    user_owed_amounts: Dict[int, Decimal] = {}
    processed_items = 0

    for item in relevant_items:
        if item.price is None or item.price <= Decimal("0"):
            if expense_model.item_id:
                raise InvalidOperationError(
                    f"Item ID {expense_model.item_id} must have a positive price for ITEM_BASED expense."
                )
            continue

        if not item.added_by_user:
            logger.error(f"Item ID {item.id} is missing added_by_user relationship.")
            raise InvalidOperationError(f"Data integrity issue: Item {item.id} is missing adder information.")

        adder_id = item.added_by_user.id
        calculated_total += item.price
        user_owed_amounts[adder_id] = user_owed_amounts.get(adder_id, Decimal("0")) + item.price
        processed_items += 1

    if processed_items == 0:
        raise InvalidOperationError(
            f"No items with positive prices found in list {expense_model.list_id} to create ITEM_BASED expense."
        )

    if round_money_func(calculated_total) != expense_model.total_amount:
        raise InvalidOperationError(
            f"Expense total amount ({expense_model.total_amount}) does not match the "
            f"calculated total from item prices ({calculated_total})."
        )

    splits = [
        ExpenseSplitModel(
            user_id=user_id,
            owed_amount=round_money_func(owed_amount),
            status=ExpenseSplitStatusEnum.unpaid,  # explicit default status
        )
        for user_id, owed_amount in user_owed_amounts.items()
    ]

    # Rounding each user's subtotal separately can drift from the (validated) total when
    # prices carry sub-cent precision; absorb any difference in the last split.
    allocated = sum((round_money_func(amount) for amount in user_owed_amounts.values()), Decimal("0.00"))
    if allocated != expense_model.total_amount:
        splits[-1].owed_amount = round_money_func(
            splits[-1].owed_amount + (expense_model.total_amount - allocated)
        )

    return splits


async def _validate_users_in_splits(db: AsyncSession, splits_in: PyList[ExpenseSplitCreate]) -> None:
    """Ensure each user appears at most once in the split data and that all of them exist.

    Raises:
        InvalidOperationError: If the same ``user_id`` appears more than once.
        UserNotFoundError: If any referenced user does not exist; the message lists
            the missing ids.
    """
    user_ids_in_split = [s.user_id for s in splits_in]
    if not user_ids_in_split:
        return

    # Detect duplicates explicitly. Comparing the found-set size with the list length
    # would otherwise report them as "not found" with an empty id list.
    seen_ids = set()
    duplicate_ids = set()
    for user_id in user_ids_in_split:
        if user_id in seen_ids:
            duplicate_ids.add(user_id)
        seen_ids.add(user_id)
    if duplicate_ids:
        raise InvalidOperationError(f"Duplicate users in split data: {sorted(duplicate_ids)}")

    user_results = await db.execute(select(UserModel.id).where(UserModel.id.in_(user_ids_in_split)))
    found_user_ids = {row[0] for row in user_results}

    missing_user_ids = seen_ids - found_user_ids
    if missing_user_ids:
        raise UserNotFoundError(identifier=f"users in split data: {sorted(missing_user_ids)}")


async def get_expense_by_id(db: AsyncSession, expense_id: int) -> Optional[ExpenseModel]:
    """Fetch one expense by primary key with its related objects eagerly loaded.

    Loads the same relationships as the post-creation re-fetch in ``create_expense``:
    payer, creator, list, group, item, and splits with each split's user. An expense
    read through either path is therefore equally populated, and accessing those
    attributes does not require a lazy load, which AsyncSession cannot perform
    implicitly.

    Returns:
        The expense, or ``None`` if no expense has ``expense_id``.
    """
    result = await db.execute(
        select(ExpenseModel)
        .options(
            selectinload(ExpenseModel.splits).options(selectinload(ExpenseSplitModel.user)),
            selectinload(ExpenseModel.paid_by_user),
            selectinload(ExpenseModel.created_by_user),
            selectinload(ExpenseModel.list),
            selectinload(ExpenseModel.group),
            selectinload(ExpenseModel.item),
        )
        .where(ExpenseModel.id == expense_id)
    )
    return result.scalars().first()

async def get_expenses_for_list(db: AsyncSession, list_id: int, skip: int = 0, limit: int = 100) -> Sequence[ExpenseModel]:
    """Return one page of the expenses attached to list ``list_id``, newest first.

    Rows are ordered by ``expense_date`` then ``created_at``, both descending, and
    paged with ``skip``/``limit``. Each expense has its splits, and each split's
    user, eagerly loaded.
    """
    newest_first = (ExpenseModel.expense_date.desc(), ExpenseModel.created_at.desc())
    splits_with_users = selectinload(ExpenseModel.splits).options(selectinload(ExpenseSplitModel.user))

    stmt = (
        select(ExpenseModel)
        .where(ExpenseModel.list_id == list_id)
        .order_by(*newest_first)
        .offset(skip)
        .limit(limit)
        .options(splits_with_users)
    )
    rows = await db.execute(stmt)
    return rows.scalars().all()

async def get_expenses_for_group(db: AsyncSession, group_id: int, skip: int = 0, limit: int = 100) -> Sequence[ExpenseModel]:
    """Return one page of the expenses attached to group ``group_id``, newest first.

    Rows are ordered by ``expense_date`` then ``created_at``, both descending, and
    paged with ``skip``/``limit``. Each expense has its splits, and each split's
    user, eagerly loaded.
    """
    newest_first = (ExpenseModel.expense_date.desc(), ExpenseModel.created_at.desc())
    splits_with_users = selectinload(ExpenseModel.splits).options(selectinload(ExpenseSplitModel.user))

    stmt = (
        select(ExpenseModel)
        .where(ExpenseModel.group_id == group_id)
        .order_by(*newest_first)
        .offset(skip)
        .limit(limit)
        .options(splits_with_users)
    )
    rows = await db.execute(stmt)
    return rows.scalars().all()

async def update_expense(db: AsyncSession, expense_db: ExpenseModel, expense_in: ExpenseUpdate) -> ExpenseModel:
    """Update the editable metadata of an existing expense.

    Only ``description``, ``currency`` and ``expense_date`` may change. Any other
    field set in the payload (apart from ``version``) is rejected, so splits and
    amounts stay consistent. The payload is validated in full before anything is
    assigned, so a rejected request never leaves ``expense_db`` partially modified.

    Optimistic locking: ``expense_in.version`` must equal ``expense_db.version``. On
    success the version is incremented and ``updated_at`` is set to the current UTC
    time. This also happens when only ``version`` was sent.

    Returns:
        ``expense_db``, flushed and refreshed from the database.

    Raises:
        InvalidOperationError: On a version mismatch or a non-updatable field.
        DatabaseIntegrityError: On an ``IntegrityError`` during the flush.
        DatabaseTransactionError: On any other ``SQLAlchemyError``.
    """
    if expense_db.version != expense_in.version:
        raise InvalidOperationError(
            f"Expense '{expense_db.description}' (ID: {expense_db.id}) has been modified. "
            f"Your version is {expense_in.version}, current version is {expense_db.version}. Please refresh.",
            # status_code=status.HTTP_409_CONFLICT would be set by the API layer
        )

    # `version` is the optimistic-lock token, not a column value to copy.
    update_data = expense_in.model_dump(exclude_unset=True, exclude={"version"})

    # Fields that can change without affecting splits or amounts.
    allowed_to_update = {"description", "currency", "expense_date"}

    # Reject before assigning anything. expense_db is tracked by the session, so a
    # partial assignment followed by a raise would leave dirty state that a later
    # flush/commit could persist without a version bump.
    disallowed_fields = [field for field in update_data if field not in allowed_to_update]
    if disallowed_fields:
        raise InvalidOperationError(
            f"Field '{disallowed_fields[0]}' cannot be updated. "
            f"Only {', '.join(sorted(allowed_to_update))} are allowed."
        )

    for field, value in update_data.items():
        setattr(expense_db, field, value)

    # Capture the id up front. After a failed flush the transaction is rolled back and
    # expense_db may be expired; AsyncSession cannot reload it implicitly, so the error
    # handlers below must not touch its attributes.
    expense_id = expense_db.id

    try:
        async with db.begin_nested() if db.in_transaction() else db.begin():
            expense_db.version += 1
            expense_db.updated_at = datetime.now(timezone.utc)  # set manually
            await db.flush()  # send changes and surface constraint violations here
            await db.refresh(expense_db)
        return expense_db
    except InvalidOperationError:
        raise
    except IntegrityError as e:
        logger.error(f"Database integrity error during expense update for ID {expense_id}: {str(e)}", exc_info=True)
        # The transaction context manager has already rolled back.
        raise DatabaseIntegrityError(f"Failed to update expense ID {expense_id} due to database integrity issue.") from e
    except SQLAlchemyError as e:
        logger.error(f"Database transaction error during expense update for ID {expense_id}: {str(e)}", exc_info=True)
        # The transaction context manager has already rolled back.
        raise DatabaseTransactionError(f"Failed to update expense ID {expense_id} due to a database transaction error.") from e
    # No generic Exception catch: non-SQLAlchemy errors propagate unchanged.


async def delete_expense(db: AsyncSession, expense_db: ExpenseModel, expected_version: Optional[int] = None) -> None:
    """Delete an expense, optionally guarded by an optimistic-locking version check.

    Args:
        db: Database session.
        expense_db: The persistent expense to delete.
        expected_version: If given, it must equal ``expense_db.version``; otherwise
            the delete is refused.

    Raises:
        InvalidOperationError: If ``expected_version`` is given and does not match.
        DatabaseIntegrityError: On an ``IntegrityError`` while flushing the delete.
        DatabaseTransactionError: On any other ``SQLAlchemyError``.

    NOTE(review): associated ExpenseSplits are presumably removed by an ON DELETE
    CASCADE foreign key or an ORM cascade. Neither is visible here; confirm against
    the model definitions.
    """
    if expected_version is not None and expense_db.version != expected_version:
        raise InvalidOperationError(
            f"Expense '{expense_db.description}' (ID: {expense_db.id}) cannot be deleted. "
            f"Your expected version {expected_version} does not match current version {expense_db.version}. Please refresh.",
            # status_code=status.HTTP_409_CONFLICT
        )

    # Capture the id up front. A failed delete rolls back and may expire expense_db,
    # and AsyncSession cannot reload it implicitly inside the error handlers.
    expense_id = expense_db.id

    try:
        async with db.begin_nested() if db.in_transaction() else db.begin():
            await db.delete(expense_db)
            await db.flush()  # emit the DELETE now so errors surface inside this block
    except InvalidOperationError:
        raise
    except IntegrityError as e:
        logger.error(f"Database integrity error during expense deletion for ID {expense_id}: {str(e)}", exc_info=True)
        # The transaction context manager has already rolled back.
        raise DatabaseIntegrityError(f"Failed to delete expense ID {expense_id} due to database integrity issue.") from e
    except SQLAlchemyError as e:
        logger.error(f"Database transaction error during expense deletion for ID {expense_id}: {str(e)}", exc_info=True)
        # The transaction context manager has already rolled back.
        raise DatabaseTransactionError(f"Failed to delete expense ID {expense_id} due to a database transaction error.") from e
    return None

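# Note: InvalidOperationError, ExpenseOperationError and the Database*Error exceptions
# raised in this module are imported from app.core.exceptions. If they are not already
# HTTPException subclasses, API endpoints should translate them into appropriate HTTP
# responses, e.g. 409 Conflict for the version-mismatch errors in update/delete.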