import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from sqlalchemy.exc import IntegrityError, OperationalError
from decimal import Decimal, ROUND_HALF_UP
from datetime import datetime, timezone
from typing import List as PyList, Optional

from app.crud.expense import (
    create_expense,
    get_expense_by_id,
    get_expenses_for_list,
    get_expenses_for_group,
    update_expense,
    delete_expense,
    get_users_for_splitting
)
from app.schemas.expense import ExpenseCreate, ExpenseUpdate, ExpenseSplitCreate, ExpenseRead
from app.models import (
    Expense as ExpenseModel,
    ExpenseSplit as ExpenseSplitModel,
    User as UserModel,
    List as ListModel,
    Group as GroupModel,
    UserGroup as UserGroupModel,
    Item as ItemModel,
    SplitTypeEnum,
    ExpenseOverallStatusEnum, # Added
    ExpenseSplitStatusEnum    # Added
)
from app.core.exceptions import (
    ListNotFoundError,
    GroupNotFoundError,
    UserNotFoundError,
    InvalidOperationError,
    ExpenseNotFoundError,
    DatabaseTransactionError,
    ConflictError
)

# General Fixtures
@pytest.fixture
def mock_db_session():
    """Return an ``AsyncMock`` standing in for an SQLAlchemy ``AsyncSession``.

    Methods the session exposes as coroutines (``commit``, ``rollback``,
    ``refresh``, ``execute``, ``get``, ``flush``, ``delete``) are ``AsyncMock``
    so callers can ``await`` them. Synchronous methods (``add``, ``add_all``,
    ``in_transaction``) are plain ``MagicMock``.

    ``begin()`` and ``begin_nested()`` are called synchronously and their result
    is used as ``async with session.begin(): ...``. Both therefore return an
    ``AsyncMock``, which supports the async context-manager protocol, rather
    than a coroutine.
    """
    session = AsyncMock()
    session.commit = AsyncMock()
    session.rollback = AsyncMock()
    session.refresh = AsyncMock()
    session.execute = AsyncMock()
    session.get = AsyncMock()
    session.flush = AsyncMock()
    # AsyncSession.delete is a coroutine, so the mock must be awaitable.
    session.delete = AsyncMock()
    session.add = MagicMock()
    # Without this, the AsyncMock parent would auto-create an async add_all.
    session.add_all = MagicMock()
    session.in_transaction = MagicMock(return_value=False)
    # Return an async context manager (not a coroutine) from both transaction
    # entry points so `async with` works for outer and nested transactions.
    session.begin = MagicMock(return_value=AsyncMock())
    session.begin_nested = MagicMock(return_value=AsyncMock())
    return session

@pytest.fixture
def basic_user_model():
    """Primary test user (id 1, version 1)."""
    user_fields = {
        "id": 1,
        "name": "Test User",
        "email": "test@example.com",
        "version": 1,
    }
    return UserModel(**user_fields)

@pytest.fixture
def another_user_model():
    """Second test user (id 2, version 1)."""
    user_fields = {
        "id": 2,
        "name": "Another User",
        "email": "another@example.com",
        "version": 1,
    }
    return UserModel(**user_fields)

@pytest.fixture
def basic_group_model():
    """Group (id 1, version 1) with no member associations preset.

    Tests that need members (e.g. the get_users_for_splitting tests) assign
    ``member_associations`` themselves.
    """
    return GroupModel(id=1, name="Test Group", version=1)

@pytest.fixture
def basic_list_model(basic_group_model, basic_user_model):
    """List (id 1) belonging to the basic group and created by the basic user."""
    group, creator = basic_group_model, basic_user_model
    return ListModel(
        id=1,
        name="Test List",
        group_id=group.id,
        group=group,
        created_by_id=creator.id,
        creator=creator,
        version=1,
    )

@pytest.fixture
def expense_create_data_equal_split_list_ctx(basic_list_model, basic_user_model):
    """ExpenseCreate for a 30.00 USD equal split attached to a list.

    ``group_id`` is left as None; the group is meant to come from the list.
    """
    today_utc = datetime.now(timezone.utc).date()
    return ExpenseCreate(
        description="Grocery run",
        total_amount=Decimal("30.00"),
        currency="USD",
        expense_date=today_utc,
        split_type=SplitTypeEnum.EQUAL,
        list_id=basic_list_model.id,
        group_id=None,
        item_id=None,
        paid_by_user_id=basic_user_model.id,
        splits_in=None,
    )

@pytest.fixture
def expense_create_data_equal_split_group_ctx(basic_group_model, basic_user_model):
    """ExpenseCreate for a 50.00 USD equal split attached directly to a group."""
    today_utc = datetime.now(timezone.utc).date()
    return ExpenseCreate(
        description="Movies",
        total_amount=Decimal("50.00"),
        currency="USD",
        expense_date=today_utc,
        split_type=SplitTypeEnum.EQUAL,
        list_id=None,
        group_id=basic_group_model.id,
        item_id=None,
        paid_by_user_id=basic_user_model.id,
        splits_in=None,
    )

@pytest.fixture
def expense_create_data_exact_split(basic_group_model, basic_user_model, another_user_model):
    """ExpenseCreate for a 100.00 USD group expense split 60.00 / 40.00 by exact amounts."""
    shares = [
        (basic_user_model.id, Decimal("60.00")),
        (another_user_model.id, Decimal("40.00")),
    ]
    return ExpenseCreate(
        description="Dinner",
        total_amount=Decimal("100.00"),
        expense_date=datetime.now(timezone.utc).date(),
        currency="USD",
        split_type=SplitTypeEnum.EXACT_AMOUNTS,
        group_id=basic_group_model.id,
        paid_by_user_id=basic_user_model.id,
        splits_in=[
            ExpenseSplitCreate(user_id=user_id, owed_amount=amount)
            for user_id, amount in shares
        ],
    )

@pytest.fixture
def expense_update_data():
    """Partial update of description and amount, carrying version 1.

    The version is compared against the stored expense; see the
    version-conflict test.
    """
    changes = {
        "description": "Updated Dinner",
        "total_amount": Decimal("120.00"),
    }
    return ExpenseUpdate(**changes, version=1)

@pytest.fixture
def db_expense_model(expense_create_data_equal_split_group_ctx, basic_user_model, basic_group_model, another_user_model):
    """Persisted-looking ExpenseModel (id 1, version 1) built from the group equal-split data.

    Paid and created by the basic user. It has one equal split per member
    (basic user and the other user). Split user ids come from the user
    fixtures and split amounts from the expense total, so the fixture stays
    consistent if those fixtures change.
    """
    data = expense_create_data_equal_split_group_ctx
    # A single timestamp keeps created_at and updated_at identical.
    now = datetime.now(timezone.utc)
    expense = ExpenseModel(
        id=1,
        description=data.description,
        total_amount=data.total_amount,
        currency=data.currency,
        expense_date=data.expense_date,
        split_type=data.split_type,
        list_id=data.list_id,
        group_id=data.group_id,
        group=basic_group_model,
        item_id=data.item_id,
        paid_by_user_id=data.paid_by_user_id,
        created_by_user_id=basic_user_model.id,
        paid_by=basic_user_model,
        created_by_user=basic_user_model,
        version=1,
        created_at=now,
        updated_at=now,
    )
    # Simulate the existing equal splits (50.00 / 2 = 25.00 each with the default data).
    members = (basic_user_model, another_user_model)
    share = (data.total_amount / len(members)).quantize(Decimal("0.01"), rounding=ROUND_HALF_UP)
    expense.splits = [
        ExpenseSplitModel(id=split_id, expense_id=expense.id, user_id=member.id, owed_amount=share, version=1)
        for split_id, member in enumerate(members, start=1)
    ]
    return expense

# Tests for get_users_for_splitting
@pytest.mark.asyncio
async def test_get_users_for_splitting_group_context(mock_db_session, basic_group_model, basic_user_model, another_user_model):
    """Group context: every member of the group is returned as a split participant."""
    members = (basic_user_model, another_user_model)
    basic_group_model.member_associations = [
        UserGroupModel(user=member, user_id=member.id, group_id=basic_group_model.id)
        for member in members
    ]
    # session.get resolves the group id to the fixture group.
    mock_db_session.get.return_value = basic_group_model

    users = await get_users_for_splitting(
        mock_db_session,
        expense_group_id=basic_group_model.id,
        expense_list_id=None,
        expense_paid_by_user_id=basic_user_model.id,
    )

    assert len(users) == len(members)
    for member in members:
        assert member in users

@pytest.mark.asyncio
async def test_get_users_for_splitting_list_context(mock_db_session, basic_list_model, basic_group_model, basic_user_model, another_user_model):
    """List context: participants are the members of the group that owns the list."""
    members = (basic_user_model, another_user_model)
    basic_group_model.member_associations = [
        UserGroupModel(user=member, user_id=member.id, group_id=basic_group_model.id)
        for member in members
    ]
    # Make sure the list points at the group whose members were just set.
    basic_list_model.group = basic_group_model
    # session.get resolves the list id to the fixture list.
    mock_db_session.get.return_value = basic_list_model

    users = await get_users_for_splitting(
        mock_db_session,
        expense_group_id=None,
        expense_list_id=basic_list_model.id,
        expense_paid_by_user_id=basic_user_model.id,
    )

    assert len(users) == len(members)
    for member in members:
        assert member in users

# --- create_expense Tests ---
@pytest.mark.asyncio
async def test_create_expense_equal_split_group_success(mock_db_session, expense_create_data_equal_split_group_ctx, basic_user_model, basic_group_model, another_user_model):
    """Equal split in a group: each member owes an equal share and everything starts unpaid.

    The refresh mock links the splits that ``create_expense`` itself produced
    instead of fabricating splits with hard-coded amounts. The amount and
    status assertions therefore exercise the code under test, not the mock.
    """
    # paid_by user first, then the group
    mock_db_session.get.side_effect = [basic_user_model, basic_group_model]

    def splits_added_to_session():
        """Return the ExpenseSplitModel objects passed to session.add / session.add_all so far."""
        found = []
        for recorded in mock_db_session.add.call_args_list:
            if recorded.args and isinstance(recorded.args[0], ExpenseSplitModel):
                found.append(recorded.args[0])
        for recorded in mock_db_session.add_all.call_args_list:
            if recorded.args:
                found.extend(obj for obj in recorded.args[0] if isinstance(obj, ExpenseSplitModel))
        return found

    # Mock get_users_for_splitting directly
    with patch('app.crud.expense.get_users_for_splitting', new_callable=AsyncMock) as mock_get_users:
        mock_get_users.return_value = [basic_user_model, another_user_model]

        async def mock_refresh(instance, attribute_names=None, with_for_update=None):
            if isinstance(instance, ExpenseModel):
                instance.id = 1  # Simulate ID assignment after flush
                instance.version = 1
                instance.created_at = datetime.now(timezone.utc)
                instance.updated_at = datetime.now(timezone.utc)
                # Simulate the DB linking separately-added splits to the expense.
                # Splits already attached through the relationship are left as-is.
                added = splits_added_to_session()
                if added:
                    instance.splits = added
            return None
        mock_db_session.refresh.side_effect = mock_refresh

        created_expense = await create_expense(mock_db_session, expense_create_data_equal_split_group_ctx, current_user_id=basic_user_model.id)

        mock_db_session.add.assert_called()
        mock_db_session.flush.assert_called_once()
        mock_db_session.refresh.assert_called_once()
        assert created_expense is not None
        assert created_expense.total_amount == expense_create_data_equal_split_group_ctx.total_amount
        assert created_expense.split_type == SplitTypeEnum.EQUAL
        assert len(created_expense.splits) == 2

        expected_amount_per_user = (expense_create_data_equal_split_group_ctx.total_amount / 2).quantize(Decimal("0.01"), rounding=ROUND_HALF_UP)
        for split in created_expense.splits:
            assert split.owed_amount == expected_amount_per_user
            assert split.status == ExpenseSplitStatusEnum.unpaid  # Verify initial split status

        assert created_expense.overall_settlement_status == ExpenseOverallStatusEnum.unpaid  # Verify initial expense status

@pytest.mark.asyncio
async def test_create_expense_exact_split_success(mock_db_session, expense_create_data_exact_split, basic_user_model, basic_group_model, another_user_model):
    """Exact-amount split: splits mirror splits_in (60.00 / 40.00) and all statuses start unpaid.

    The refresh mock links the splits that ``create_expense`` itself produced
    rather than fabricating 60/40 splits. Otherwise the amount assertions
    would only check the mock's own data.
    """
    # Payer, Group, User1 in split, User2 in split
    mock_db_session.get.side_effect = [basic_user_model, basic_group_model, basic_user_model, another_user_model]

    def splits_added_to_session():
        """Return the ExpenseSplitModel objects passed to session.add / session.add_all so far."""
        found = []
        for recorded in mock_db_session.add.call_args_list:
            if recorded.args and isinstance(recorded.args[0], ExpenseSplitModel):
                found.append(recorded.args[0])
        for recorded in mock_db_session.add_all.call_args_list:
            if recorded.args:
                found.extend(obj for obj in recorded.args[0] if isinstance(obj, ExpenseSplitModel))
        return found

    async def mock_refresh(instance, attribute_names=None, with_for_update=None):
        if isinstance(instance, ExpenseModel):
            instance.id = 2
            instance.version = 1
            # Simulate the DB linking separately-added splits to the expense.
            # Splits already attached through the relationship are left as-is.
            added = splits_added_to_session()
            if added:
                instance.splits = added
        return None
    mock_db_session.refresh.side_effect = mock_refresh

    created_expense = await create_expense(mock_db_session, expense_create_data_exact_split, current_user_id=basic_user_model.id)

    mock_db_session.add.assert_called()
    mock_db_session.flush.assert_called_once()
    assert created_expense is not None
    assert created_expense.split_type == SplitTypeEnum.EXACT_AMOUNTS
    assert len(created_expense.splits) == 2
    assert created_expense.splits[0].owed_amount == Decimal("60.00")
    assert created_expense.splits[1].owed_amount == Decimal("40.00")
    for split in created_expense.splits:
        assert split.status == ExpenseSplitStatusEnum.unpaid  # Verify initial split status

    assert created_expense.overall_settlement_status == ExpenseOverallStatusEnum.unpaid  # Verify initial expense status

@pytest.mark.asyncio
async def test_create_expense_payer_not_found(mock_db_session, expense_create_data_equal_split_group_ctx):
    """An unknown payer aborts creation with UserNotFoundError."""
    # Only one lookup is scripted: the payer, which resolves to nothing.
    # The group lookup is not expected to happen.
    mock_db_session.get.side_effect = [None]
    creator_id = 999  # current_user_id is the creator; the payer comes from paid_by_user_id in the schema
    with pytest.raises(UserNotFoundError):
        await create_expense(mock_db_session, expense_create_data_equal_split_group_ctx, creator_id)

@pytest.mark.asyncio
async def test_create_expense_no_list_or_group(mock_db_session, expense_create_data_equal_split_group_ctx, basic_user_model):
    """An expense attached to neither a list nor a group is rejected."""
    mock_db_session.get.return_value = basic_user_model  # payer lookup succeeds
    orphan_expense = expense_create_data_equal_split_group_ctx.model_copy()
    for context_field in ("list_id", "group_id"):
        setattr(orphan_expense, context_field, None)
    with pytest.raises(InvalidOperationError, match="Expense must be associated with a list or a group"):
        await create_expense(mock_db_session, orphan_expense, basic_user_model.id)

# --- get_expense_by_id Tests ---
@pytest.mark.asyncio
async def test_get_expense_by_id_found(mock_db_session, db_expense_model):
    """An existing expense is loaded through session.get using the expense model and its id."""
    mock_db_session.get.return_value = db_expense_model

    expense = await get_expense_by_id(mock_db_session, db_expense_model.id)

    assert expense is not None
    assert expense.id == db_expense_model.id
    # Freshly built MagicMock() placeholders compare equal only to themselves,
    # so they can never match the real loader options. Check just the model
    # class and primary key that were passed positionally.
    mock_db_session.get.assert_called_once()
    lookup_args = mock_db_session.get.call_args.args
    assert lookup_args[:2] == (ExpenseModel, db_expense_model.id)

@pytest.mark.asyncio
async def test_get_expense_by_id_not_found(mock_db_session):
    """A missing expense yields None instead of raising."""
    missing_expense_id = 999
    mock_db_session.get.return_value = None
    found = await get_expense_by_id(mock_db_session, missing_expense_id)
    assert found is None

# --- get_expenses_for_list Tests ---
@pytest.mark.asyncio
async def test_get_expenses_for_list_success(mock_db_session, db_expense_model, basic_list_model):
    """Expenses for a list are returned from a single query.

    ``session.execute`` is awaited, but the Result it returns is synchronous
    (``result.scalars().all()``). The result mock must therefore be a MagicMock;
    with an AsyncMock, ``scalars()`` would return a coroutine that has no ``.all()``.
    """
    mock_result = MagicMock()
    mock_result.scalars.return_value.all.return_value = [db_expense_model]
    # Also support the result.scalars().unique().all() form used with joined eager loads.
    mock_result.scalars.return_value.unique.return_value.all.return_value = [db_expense_model]
    mock_db_session.execute.return_value = mock_result

    expenses = await get_expenses_for_list(mock_db_session, basic_list_model.id)

    assert len(expenses) == 1
    assert expenses[0].id == db_expense_model.id
    mock_db_session.execute.assert_called_once()

# --- get_expenses_for_group Tests ---
@pytest.mark.asyncio
async def test_get_expenses_for_group_success(mock_db_session, db_expense_model, basic_group_model):
    """Expenses for a group are returned from a single query.

    ``session.execute`` is awaited, but the Result it returns is synchronous
    (``result.scalars().all()``). The result mock must therefore be a MagicMock;
    with an AsyncMock, ``scalars()`` would return a coroutine that has no ``.all()``.
    """
    mock_result = MagicMock()
    mock_result.scalars.return_value.all.return_value = [db_expense_model]
    # Also support the result.scalars().unique().all() form used with joined eager loads.
    mock_result.scalars.return_value.unique.return_value.all.return_value = [db_expense_model]
    mock_db_session.execute.return_value = mock_result

    expenses = await get_expenses_for_group(mock_db_session, basic_group_model.id)

    assert len(expenses) == 1
    assert expenses[0].id == db_expense_model.id
    mock_db_session.execute.assert_called_once()

# --- update_expense Tests ---
@pytest.mark.asyncio
async def test_update_expense_success(mock_db_session, db_expense_model, expense_update_data, basic_user_model):
    """A matching-version update applies the changes, persists them and bumps the version."""
    # Capture the version up front: update_expense works on the fetched
    # instance (see the add/refresh assertions below), so reading
    # db_expense_model.version after the call would compare the value with itself.
    original_version = db_expense_model.version
    expense_update_data.version = original_version  # Match version

    # Simulate that the db_expense_model is returned by session.get
    mock_db_session.get.return_value = db_expense_model

    updated_expense = await update_expense(mock_db_session, db_expense_model.id, expense_update_data, basic_user_model.id)

    mock_db_session.add.assert_called_once_with(db_expense_model)
    mock_db_session.flush.assert_called_once()
    mock_db_session.refresh.assert_called_once_with(db_expense_model)
    assert updated_expense.description == expense_update_data.description
    assert updated_expense.total_amount == expense_update_data.total_amount
    assert updated_expense.version == original_version + 1  # Version incremented by the function

@pytest.mark.asyncio
async def test_update_expense_not_found(mock_db_session, expense_update_data, basic_user_model):
    """Updating an expense that does not exist raises ExpenseNotFoundError."""
    unknown_expense_id = 999
    mock_db_session.get.return_value = None
    with pytest.raises(ExpenseNotFoundError):
        await update_expense(mock_db_session, unknown_expense_id, expense_update_data, basic_user_model.id)

@pytest.mark.asyncio
async def test_update_expense_version_conflict(mock_db_session, db_expense_model, expense_update_data, basic_user_model):
    """A version that differs from the stored one raises ConflictError and rolls back."""
    mismatched_version = db_expense_model.version + 1
    expense_update_data.version = mismatched_version
    mock_db_session.get.return_value = db_expense_model

    with pytest.raises(ConflictError):
        await update_expense(mock_db_session, db_expense_model.id, expense_update_data, basic_user_model.id)

    mock_db_session.rollback.assert_called_once()

# --- delete_expense Tests ---
@pytest.mark.asyncio
async def test_delete_expense_success(mock_db_session, db_expense_model, basic_user_model):
    """A found expense is deleted inside a transaction that exits cleanly."""
    mock_db_session.get.return_value = db_expense_model  # Simulate expense found

    await delete_expense(mock_db_session, db_expense_model.id, basic_user_model.id)

    mock_db_session.delete.assert_called_once_with(db_expense_model)
    # The session fixture makes begin() return an async context manager, and the
    # transaction commits when `async with` exits cleanly. No explicit .commit()
    # is called on the context object. Inspect the recorded return value
    # rather than calling begin() again, which would register an extra call.
    mock_db_session.begin.assert_called_once()
    transaction = mock_db_session.begin.return_value
    transaction.__aexit__.assert_awaited_once_with(None, None, None)

@pytest.mark.asyncio
async def test_delete_expense_not_found(mock_db_session, basic_user_model):
    """Deleting an unknown expense raises ExpenseNotFoundError without calling session.rollback."""
    unknown_expense_id = 999
    mock_db_session.get.return_value = None
    with pytest.raises(ExpenseNotFoundError):
        await delete_expense(mock_db_session, unknown_expense_id, basic_user_model.id)
    # Only the session-level rollback is checked here. A transaction context
    # manager, if one is used, may still roll back on its own exit.
    mock_db_session.rollback.assert_not_called()

@pytest.mark.asyncio
async def test_delete_expense_db_error(mock_db_session, db_expense_model, basic_user_model):
    """A DB failure during delete surfaces as DatabaseTransactionError and aborts the transaction."""
    mock_db_session.get.return_value = db_expense_model
    mock_db_session.delete.side_effect = OperationalError("mock op error", "params", "orig")

    with pytest.raises(DatabaseTransactionError):
        await delete_expense(mock_db_session, db_expense_model.id, basic_user_model.id)

    # Rolling back is the job of the transaction context manager: `async with`
    # never calls .rollback() on the context object, it calls __aexit__ with the
    # in-flight exception. Inspect the recorded return value rather than
    # calling begin() again, which would register an extra call.
    mock_db_session.begin.assert_called_once()
    context_exit = mock_db_session.begin.return_value.__aexit__
    context_exit.assert_awaited_once()
    exc_type = context_exit.await_args.args[0]
    assert exc_type is not None and issubclass(exc_type, Exception)