import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from sqlalchemy.exc import IntegrityError, OperationalError
from decimal import Decimal, ROUND_HALF_UP
from datetime import datetime, timezone
from typing import List as PyList, Optional

from app.crud.expense import (
    create_expense,
    get_expense_by_id,
    get_expenses_for_list,
    get_expenses_for_group,
    update_expense, # Assuming update_expense exists
    delete_expense,  # Assuming delete_expense exists
    get_users_for_splitting # Helper, might test indirectly
)
from app.schemas.expense import ExpenseCreate, ExpenseUpdate, ExpenseSplitCreate
from app.models import (
    Expense as ExpenseModel,
    ExpenseSplit as ExpenseSplitModel,
    User as UserModel,
    List as ListModel,
    Group as GroupModel,
    UserGroup as UserGroupModel,
    Item as ItemModel,
    SplitTypeEnum
)
from app.core.exceptions import (
    ListNotFoundError,
    GroupNotFoundError,
    UserNotFoundError,
    InvalidOperationError
)

# General Fixtures
@pytest.fixture
def mock_db_session():
    """Return a mock standing in for an SQLAlchemy ``AsyncSession``.

    Coroutine methods of ``AsyncSession`` (``commit``, ``rollback``,
    ``refresh``, ``execute``, ``get``, ``flush``, ``delete``) are
    ``AsyncMock`` so they can be awaited. Synchronous methods (``add``,
    ``add_all``, ``in_transaction``) are ``MagicMock`` so calling them does
    not create a stray un-awaited coroutine.

    ``begin``/``begin_nested`` are plain calls on a real session that return
    an async context manager; a ``MagicMock`` return value supports
    ``async with`` out of the box, whereas an ``AsyncMock`` would return a
    coroutine that cannot be used with ``async with``.

    ``in_transaction()`` reports ``False`` by default; tests may override it.
    """
    session = AsyncMock()
    # Transaction helpers: called synchronously, used as ``async with``.
    session.begin = MagicMock()
    session.begin_nested = MagicMock()
    # Awaitable session operations.
    session.commit = AsyncMock()
    session.rollback = AsyncMock()
    session.refresh = AsyncMock()
    session.execute = AsyncMock()
    session.get = AsyncMock()
    session.flush = AsyncMock()
    session.delete = AsyncMock()  # AsyncSession.delete is a coroutine method
    # Synchronous session operations.
    session.add = MagicMock()
    session.add_all = MagicMock()
    session.in_transaction = MagicMock(return_value=False)
    return session

@pytest.fixture
def basic_user_model():
    """Return the primary test user (id=1)."""
    user = UserModel(
        id=1,
        name="Test User",
        email="test@example.com",
    )
    return user

@pytest.fixture
def another_user_model():
    """Return a second test user (id=2) for multi-user split scenarios."""
    user = UserModel(
        id=2,
        name="Another User",
        email="another@example.com",
    )
    return user

@pytest.fixture
def basic_group_model():
    """Return a test group (id=1) without populating ``member_associations``.

    Tests that need group members (e.g. for ``get_users_for_splitting``)
    assign ``member_associations`` on the returned instance themselves,
    using the user fixtures they request as parameters.
    """
    return GroupModel(id=1, name="Test Group")

@pytest.fixture
def basic_list_model(basic_group_model, basic_user_model):
    """Return a list (id=1) belonging to the basic group, created by the basic user."""
    owning_group = basic_group_model
    list_creator = basic_user_model
    return ListModel(
        id=1,
        name="Test List",
        group_id=owning_group.id,
        group=owning_group,
        creator_id=list_creator.id,
        creator=list_creator,
    )

@pytest.fixture
def expense_create_data_equal_split_list_ctx(basic_list_model, basic_user_model):
    """Return an EQUAL-split ``ExpenseCreate`` scoped to the basic list.

    ``group_id`` is left as ``None`` so the group is expected to be derived
    from the list.
    """
    payer_id = basic_user_model.id
    return ExpenseCreate(
        description="Grocery run",
        total_amount=Decimal("30.00"),
        currency="USD",
        expense_date=datetime.now(timezone.utc),
        split_type=SplitTypeEnum.EQUAL,
        list_id=basic_list_model.id,
        group_id=None,
        item_id=None,
        paid_by_user_id=payer_id,
        splits_in=None,
    )

@pytest.fixture
def expense_create_data_equal_split_group_ctx(basic_group_model, basic_user_model):
    """Return an EQUAL-split ``ExpenseCreate`` scoped directly to the basic group (no list)."""
    payer_id = basic_user_model.id
    return ExpenseCreate(
        description="Movies",
        total_amount=Decimal("50.00"),
        currency="USD",
        expense_date=datetime.now(timezone.utc),
        split_type=SplitTypeEnum.EQUAL,
        list_id=None,
        group_id=basic_group_model.id,
        item_id=None,
        paid_by_user_id=payer_id,
        splits_in=None,
    )

@pytest.fixture
def expense_create_data_exact_split(basic_group_model, basic_user_model, another_user_model):
    """Return an EXACT_AMOUNTS ``ExpenseCreate`` in the basic group.

    The 100.00 total is split 60.00 / 40.00 between the two test users.
    Currency and expense date are left to the schema's defaults.
    """
    explicit_splits = [
        ExpenseSplitCreate(user_id=basic_user_model.id, owed_amount=Decimal("60.00")),
        ExpenseSplitCreate(user_id=another_user_model.id, owed_amount=Decimal("40.00")),
    ]
    return ExpenseCreate(
        description="Dinner",
        total_amount=Decimal("100.00"),
        split_type=SplitTypeEnum.EXACT_AMOUNTS,
        group_id=basic_group_model.id,
        paid_by_user_id=basic_user_model.id,
        splits_in=explicit_splits,
    )

@pytest.fixture
def db_expense_model(expense_create_data_equal_split_group_ctx, basic_user_model):
    """Return a persisted-looking ``ExpenseModel`` (id=1, version=1).

    Fields are copied from the group-context EQUAL-split create payload, so
    ``list_id`` is ``None`` and ``group_id`` is the basic group's id. The
    ``paid_by`` and ``created_by_user`` relationships are pre-populated with
    the basic user; ``splits`` is not populated.
    """
    src = expense_create_data_equal_split_group_ctx
    owner = basic_user_model
    return ExpenseModel(
        id=1,
        description=src.description,
        total_amount=src.total_amount,
        currency=src.currency,
        expense_date=src.expense_date,
        split_type=src.split_type,
        list_id=src.list_id,
        group_id=src.group_id,
        item_id=src.item_id,
        paid_by_user_id=src.paid_by_user_id,
        created_by_user_id=owner.id,
        paid_by=owner,
        created_by_user=owner,
        version=1,
    )

# Direct tests for get_users_for_splitting (it is also reachable indirectly through create_expense).
@pytest.mark.asyncio
async def test_get_users_for_splitting_group_context(mock_db_session, basic_group_model, basic_user_model, another_user_model):
    """Group context: expects both members of the group to be returned."""
    basic_group_model.member_associations = [
        UserGroupModel(user=basic_user_model, user_id=basic_user_model.id),
        UserGroupModel(user=another_user_model, user_id=another_user_model.id),
    ]

    # ``await session.execute(...)`` yields a synchronous Result object, so the
    # stand-in must be a MagicMock. With AsyncMock, ``scalars()`` would return
    # a coroutine and ``.first()`` would raise AttributeError.
    mock_result = MagicMock()
    mock_result.scalars.return_value.first.return_value = basic_group_model
    mock_db_session.execute.return_value = mock_result

    users = await get_users_for_splitting(mock_db_session, expense_group_id=1, expense_list_id=None, expense_paid_by_user_id=1)
    assert len(users) == 2
    assert basic_user_model in users
    assert another_user_model in users

# --- create_expense Tests ---
@pytest.mark.asyncio
async def test_create_expense_equal_split_group_success(mock_db_session, expense_create_data_equal_split_group_ctx, basic_user_model, basic_group_model, another_user_model):
    """EQUAL split in a group context: two participants each owe half the total."""
    expense_in = expense_create_data_equal_split_group_ctx
    # Successive ``session.get`` calls; presumably payer first, then group.
    mock_db_session.get.side_effect = [basic_user_model, basic_group_model]

    # The awaited ``execute`` returns a synchronous Result, so use MagicMock;
    # an AsyncMock would make ``scalar_one_or_none()`` return a coroutine.
    mock_result = MagicMock()
    mock_result.scalar_one_or_none.return_value = ExpenseModel(
        id=1,
        description=expense_in.description,
        total_amount=expense_in.total_amount,
        currency=expense_in.currency,
        expense_date=expense_in.expense_date,
        split_type=expense_in.split_type,
        list_id=expense_in.list_id,
        group_id=expense_in.group_id,
        item_id=expense_in.item_id,
        paid_by_user_id=expense_in.paid_by_user_id,
        created_by_user_id=basic_user_model.id,
        version=1
    )
    mock_db_session.execute.return_value = mock_result

    with patch('app.crud.expense.get_users_for_splitting', new_callable=AsyncMock) as mock_get_users:
        mock_get_users.return_value = [basic_user_model, another_user_model]
        created_expense = await create_expense(mock_db_session, expense_in, current_user_id=1)

        mock_db_session.add.assert_called()
        mock_db_session.flush.assert_called_once()
        assert created_expense is not None
        assert created_expense.total_amount == expense_in.total_amount
        assert created_expense.split_type == SplitTypeEnum.EQUAL
        assert len(created_expense.splits) == 2

        expected_amount_per_user = (expense_in.total_amount / 2).quantize(Decimal("0.01"), rounding=ROUND_HALF_UP)
        for split in created_expense.splits:
            assert split.owed_amount == expected_amount_per_user

@pytest.mark.asyncio
async def test_create_expense_exact_split_success(mock_db_session, expense_create_data_exact_split, basic_user_model, basic_group_model, another_user_model):
    """EXACT_AMOUNTS split: the two explicit splits from the payload are created."""
    expense_in = expense_create_data_exact_split
    mock_db_session.get.side_effect = [basic_user_model, basic_group_model]

    # Result objects from an awaited ``execute`` are synchronous -> MagicMock.
    mock_result = MagicMock()
    mock_result.scalar_one_or_none.return_value = ExpenseModel(
        id=1,
        description=expense_in.description,
        total_amount=expense_in.total_amount,
        currency="USD",
        expense_date=expense_in.expense_date,
        split_type=expense_in.split_type,
        list_id=expense_in.list_id,
        group_id=expense_in.group_id,
        item_id=expense_in.item_id,
        paid_by_user_id=expense_in.paid_by_user_id,
        created_by_user_id=basic_user_model.id,
        version=1
    )
    mock_db_session.execute.return_value = mock_result

    created_expense = await create_expense(mock_db_session, expense_in, current_user_id=1)

    mock_db_session.add.assert_called()
    mock_db_session.flush.assert_called_once()
    assert created_expense is not None
    assert created_expense.split_type == SplitTypeEnum.EXACT_AMOUNTS
    assert len(created_expense.splits) == 2

@pytest.mark.asyncio
async def test_create_expense_payer_not_found(mock_db_session, expense_create_data_equal_split_group_ctx):
    """A missing payer (``session.get`` returns ``None``) raises UserNotFoundError."""
    missing_payer = None
    mock_db_session.get.return_value = missing_payer
    payload = expense_create_data_equal_split_group_ctx
    with pytest.raises(UserNotFoundError):
        await create_expense(mock_db_session, payload, 1)

@pytest.mark.asyncio
async def test_create_expense_no_list_or_group(mock_db_session, expense_create_data_equal_split_group_ctx, basic_user_model):
    """An expense with neither ``list_id`` nor ``group_id`` is rejected."""
    mock_db_session.get.return_value = basic_user_model  # payer lookup succeeds
    # ``model_copy(update=...)`` skips validation, so the invalid payload can be
    # built even if the schema validates on assignment; the error must come
    # from ``create_expense`` itself, inside ``pytest.raises``.
    expense_data = expense_create_data_equal_split_group_ctx.model_copy(
        update={"list_id": None, "group_id": None}
    )
    with pytest.raises(InvalidOperationError, match="Expense must be associated with a list or a group"):
        await create_expense(mock_db_session, expense_data, 1)

# --- get_expense_by_id Tests ---
@pytest.mark.asyncio
async def test_get_expense_by_id_found(mock_db_session, db_expense_model):
    """An existing id returns the matching expense via a single query."""
    # ``result.scalars().first()`` is synchronous on a real Result -> MagicMock.
    mock_result = MagicMock()
    mock_result.scalars.return_value.first.return_value = db_expense_model
    mock_db_session.execute.return_value = mock_result

    expense = await get_expense_by_id(mock_db_session, 1)
    assert expense is not None
    assert expense.id == 1
    mock_db_session.execute.assert_called_once()

@pytest.mark.asyncio
async def test_get_expense_by_id_not_found(mock_db_session):
    """An unknown id yields ``None``."""
    # Synchronous Result stand-in; see test_get_expense_by_id_found.
    mock_result = MagicMock()
    mock_result.scalars.return_value.first.return_value = None
    mock_db_session.execute.return_value = mock_result

    expense = await get_expense_by_id(mock_db_session, 999)
    assert expense is None
    mock_db_session.execute.assert_called_once()

# --- get_expenses_for_list Tests ---
@pytest.mark.asyncio
async def test_get_expenses_for_list_success(mock_db_session, db_expense_model):
    """Expenses attached to list 1 are returned."""
    # The shared fixture is built from a group-context payload (list_id=None),
    # so attach it to the list being queried before asserting on list_id.
    db_expense_model.list_id = 1

    # ``result.scalars().all()`` is synchronous on a real Result -> MagicMock.
    mock_result = MagicMock()
    mock_result.scalars.return_value.all.return_value = [db_expense_model]
    mock_db_session.execute.return_value = mock_result

    expenses = await get_expenses_for_list(mock_db_session, list_id=1)
    assert len(expenses) == 1
    assert expenses[0].list_id == 1
    mock_db_session.execute.assert_called_once()

# --- get_expenses_for_group Tests ---
@pytest.mark.asyncio
async def test_get_expenses_for_group_success(mock_db_session, db_expense_model):
    """Expenses attached to group 1 are returned."""
    # ``result.scalars().all()`` is synchronous on a real Result -> MagicMock.
    mock_result = MagicMock()
    mock_result.scalars.return_value.all.return_value = [db_expense_model]
    mock_db_session.execute.return_value = mock_result

    expenses = await get_expenses_for_group(mock_db_session, group_id=1)
    assert len(expenses) == 1
    assert expenses[0].group_id == 1
    mock_db_session.execute.assert_called_once()

# --- Stubs for update_expense and delete_expense ---
# These will need more details once the actual implementation of update/delete is clear
# For example, how splits are handled on update, versioning, etc.

@pytest.mark.asyncio
async def test_update_expense_stub(mock_db_session):
    """Placeholder for ``update_expense`` tests; skipped until implemented.

    Intended shape: provide an existing ``ExpenseModel`` (version=1) and an
    ``ExpenseUpdate`` payload (e.g. a new description plus the expected
    version), call ``update_expense``, then assert the field changed and that
    the session was committed and refreshed. The exact signature and how
    splits are handled on update still need to be confirmed against the
    implementation.
    """
    # Skip rather than ``pass``: an empty test reports as passed and hides
    # the missing coverage.
    pytest.skip("update_expense tests not implemented yet")

@pytest.mark.asyncio
async def test_delete_expense_stub(mock_db_session):
    """Placeholder for ``delete_expense`` tests; skipped until implemented.

    Intended shape: provide an existing ``ExpenseModel`` (id=1, version=1),
    call ``delete_expense`` with the expected version, then assert
    ``session.delete`` was called with that expense and the session was
    committed. Whether associated splits are deleted (cascade) still needs
    to be confirmed against the implementation.
    """
    # Skip rather than ``pass``: an empty test reports as passed and hides
    # the missing coverage.
    pytest.skip("delete_expense tests not implemented yet")

# TODO: Add more tests for create_expense covering:
# - List context success
# - Percentage, Shares, Item-based splits
# - Error cases for each split type (e.g., total mismatch, invalid inputs)
# - Validation of list_id/group_id consistency
# - User not found in splits_in
# - Item not found for ITEM_BASED split

# TODO: Flesh out update_expense tests:
# - Success case
# - Version mismatch
# - Trying to update immutable fields
# - How splits are handled (recalculated, deleted/recreated, or not changeable)

# TODO: Flesh out delete_expense tests:
# - Success case
# - Version mismatch (if applicable)
# - Ensure associated splits are also deleted (cascade behavior)