# app/crud/list.py
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy.orm import selectinload, joinedload
from sqlalchemy import or_, and_, delete as sql_delete, func as sql_func, desc
from sqlalchemy.exc import SQLAlchemyError, IntegrityError, OperationalError
from typing import Optional, List as PyList
import logging # Add logging import

from app.schemas.list import ListStatus
from app.models import List as ListModel, UserGroup as UserGroupModel, Item as ItemModel
from app.schemas.list import ListCreate, ListUpdate
from app.core.exceptions import (
    ListNotFoundError,
    ListPermissionError,
    ListCreatorRequiredError,
    DatabaseConnectionError,
    DatabaseIntegrityError,
    DatabaseQueryError,
    DatabaseTransactionError,
    ConflictError,
    ListOperationError
)

logger: logging.Logger = logging.getLogger(__name__)  # module logger; used below to record database errors

async def create_list(db: AsyncSession, list_in: ListCreate, creator_id: int) -> ListModel:
    """Insert a new, incomplete list owned by ``creator_id`` and return it.

    If the session already has an open transaction, the work runs inside a
    SAVEPOINT (``begin_nested``). Otherwise a new transaction is started via
    ``begin()`` and ends when the ``async with`` block exits. The returned
    instance has ``creator`` and ``group`` eager-loaded.

    Raises:
        ListOperationError: the inserted row could not be selected back.
        DatabaseIntegrityError: the insert violated a database constraint.
        DatabaseConnectionError: SQLAlchemy raised an ``OperationalError``.
        DatabaseTransactionError: any other ``SQLAlchemyError``.
    """
    try:
        tx_scope = db.begin_nested() if db.in_transaction() else db.begin()
        async with tx_scope:
            new_list = ListModel(
                name=list_in.name,
                description=list_in.description,
                group_id=list_in.group_id,
                created_by_id=creator_id,
                is_complete=False,
            )
            db.add(new_list)
            # Flush emits the INSERT so the primary key is populated.
            await db.flush()

            # Select the row again so the relationships needed by the response
            # are loaded eagerly (items are intentionally not loaded here).
            reload_stmt = (
                select(ListModel)
                .where(ListModel.id == new_list.id)
                .options(selectinload(ListModel.creator), selectinload(ListModel.group))
            )
            reload_result = await db.execute(reload_stmt)
            fresh_list = reload_result.scalar_one_or_none()

            if fresh_list is None:
                raise ListOperationError("Failed to load list after creation.")
            return fresh_list
    except IntegrityError as exc:
        logger.error(f"Database integrity error during list creation: {exc}", exc_info=True)
        raise DatabaseIntegrityError(f"Failed to create list: {exc}")
    except OperationalError as exc:
        logger.error(f"Database connection error during list creation: {exc}", exc_info=True)
        raise DatabaseConnectionError(f"Database connection error: {exc}")
    except SQLAlchemyError as exc:
        logger.error(f"Unexpected SQLAlchemy error during list creation: {exc}", exc_info=True)
        raise DatabaseTransactionError(f"Failed to create list: {exc}")

async def get_lists_for_user(db: AsyncSession, user_id: int) -> PyList[ListModel]:
    """Return every list accessible to ``user_id``, most recently updated first.

    A list is accessible when it is one of the user's own personal lists
    (``group_id IS NULL`` and created by the user), or when it belongs to a
    group in which the user has a ``UserGroup`` membership row.

    The following are eager-loaded: ``creator``, ``group`` and ``items``, plus
    each item's ``added_by_user`` and ``completed_by_user``.

    Raises:
        DatabaseConnectionError: SQLAlchemy raised an ``OperationalError``.
        DatabaseQueryError: any other ``SQLAlchemyError``.
    """
    try:
        # Membership is resolved as a subquery in the same statement, so there
        # is one round trip and one consistent read. If the user belongs to no
        # group, ``IN (<empty>)`` matches nothing, exactly as if the condition
        # were omitted.
        member_group_ids = select(UserGroupModel.group_id).where(
            UserGroupModel.user_id == user_id
        )

        query = (
            select(ListModel)
            .where(
                or_(
                    and_(ListModel.created_by_id == user_id, ListModel.group_id.is_(None)),
                    ListModel.group_id.in_(member_group_ids),
                )
            )
            .options(
                selectinload(ListModel.creator),
                selectinload(ListModel.group),
                selectinload(ListModel.items).options(
                    joinedload(ItemModel.added_by_user),
                    joinedload(ItemModel.completed_by_user),
                ),
            )
            .order_by(ListModel.updated_at.desc())
        )

        result = await db.execute(query)
        return result.scalars().all()
    except OperationalError as e:
        raise DatabaseConnectionError(f"Failed to connect to database: {str(e)}") from e
    except SQLAlchemyError as e:
        raise DatabaseQueryError(f"Failed to query user lists: {str(e)}") from e

async def get_list_by_id(db: AsyncSession, list_id: int, load_items: bool = False) -> Optional[ListModel]:
    """Fetch one list by primary key, or ``None`` if it does not exist.

    ``creator`` and ``group`` are always eager-loaded. When ``load_items`` is
    true, ``items`` is loaded as well, together with each item's
    ``added_by_user`` and ``completed_by_user``.

    Raises:
        DatabaseConnectionError: SQLAlchemy raised an ``OperationalError``.
        DatabaseQueryError: any other ``SQLAlchemyError``.
    """
    try:
        loaders = [selectinload(ListModel.creator), selectinload(ListModel.group)]
        if load_items:
            loaders.append(
                selectinload(ListModel.items).options(
                    joinedload(ItemModel.added_by_user),
                    joinedload(ItemModel.completed_by_user),
                )
            )

        stmt = select(ListModel).where(ListModel.id == list_id).options(*loaders)
        rows = await db.execute(stmt)
        return rows.scalars().first()
    except OperationalError as err:
        raise DatabaseConnectionError(f"Failed to connect to database: {err}")
    except SQLAlchemyError as err:
        raise DatabaseQueryError(f"Failed to query list: {err}")

async def update_list(db: AsyncSession, list_db: ListModel, list_in: ListUpdate) -> ListModel:
    """Apply the fields set on ``list_in`` to ``list_db`` using optimistic locking.

    ``list_in.version`` must equal the version currently stored in the
    database. On success every explicitly-set field except ``version`` is
    copied onto ``list_db``, the version is incremented, and the row is
    selected again with ``creator`` and ``group`` eager-loaded.

    The stored version is re-read with ``SELECT ... FOR UPDATE`` inside the
    transaction rather than trusted from the in-memory ``list_db``. As a
    result, two concurrent requests carrying the same client version cannot
    both pass the check and overwrite each other, on backends that honour
    row locks.

    Raises:
        ConflictError: the client's version does not match the stored version.
        ListNotFoundError: the list row no longer exists.
        ListOperationError: the updated row could not be selected back.
        DatabaseIntegrityError: the update violated a database constraint.
        DatabaseConnectionError: SQLAlchemy raised an ``OperationalError``.
        DatabaseTransactionError: any other ``SQLAlchemyError``.
    """
    try:
        async with db.begin_nested() if db.in_transaction() else db.begin():
            # Authoritative, row-locked read of the current version. The id is
            # selected as well so a missing row is distinguishable from any
            # column value.
            version_stmt = (
                select(ListModel.id, ListModel.version)
                .where(ListModel.id == list_db.id)
                .with_for_update()
            )
            stored = (await db.execute(version_stmt)).one_or_none()
            if stored is None:
                raise ListNotFoundError(list_db.id)

            current_version = stored.version
            if current_version != list_in.version:
                raise ConflictError(
                    f"List '{list_db.name}' (ID: {list_db.id}) has been modified. "
                    f"Your version is {list_in.version}, current version is {current_version}. Please refresh."
                )

            # Only fields the client explicitly sent; version is managed here.
            update_data = list_in.model_dump(exclude_unset=True, exclude={'version'})
            for key, value in update_data.items():
                setattr(list_db, key, value)

            list_db.version = current_version + 1

            # A no-op for an instance already attached to this session; it
            # attaches list_db if it is detached.
            db.add(list_db)
            await db.flush()

            # Select the row again so the relationships needed by the response
            # are loaded (items are intentionally not loaded here).
            stmt = (
                select(ListModel)
                .where(ListModel.id == list_db.id)
                .options(
                    selectinload(ListModel.creator),
                    selectinload(ListModel.group),
                )
            )
            result = await db.execute(stmt)
            updated_list = result.scalar_one_or_none()

            if updated_list is None:  # Should not happen inside the same transaction
                raise ListOperationError("Failed to load list after update.")

            return updated_list
    except IntegrityError as e:
        logger.error(f"Database integrity error during list update: {str(e)}", exc_info=True)
        raise DatabaseIntegrityError(f"Failed to update list due to integrity constraint: {str(e)}") from e
    except OperationalError as e:
        logger.error(f"Database connection error while updating list: {str(e)}", exc_info=True)
        raise DatabaseConnectionError(f"Database connection error while updating list: {str(e)}") from e
    except ConflictError:
        raise
    except SQLAlchemyError as e:
        logger.error(f"Unexpected SQLAlchemy error during list update: {str(e)}", exc_info=True)
        raise DatabaseTransactionError(f"Failed to update list: {str(e)}") from e

async def delete_list(db: AsyncSession, list_db: ListModel) -> None:
    """Delete ``list_db`` from the database.

    Any optimistic version check must already have been done by the caller.
    If the session already has an open transaction, the delete runs inside a
    SAVEPOINT; otherwise it runs in a transaction started via ``begin()``.

    Raises:
        DatabaseConnectionError: SQLAlchemy raised an ``OperationalError``.
        DatabaseTransactionError: any other ``SQLAlchemyError``.
    """
    try:
        scope = db.begin_nested() if db.in_transaction() else db.begin()
        async with scope:
            await db.delete(list_db)
    except OperationalError as err:
        logger.error(f"Database connection error while deleting list: {err}", exc_info=True)
        raise DatabaseConnectionError(f"Database connection error while deleting list: {err}")
    except SQLAlchemyError as err:
        logger.error(f"Unexpected SQLAlchemy error while deleting list: {err}", exc_info=True)
        raise DatabaseTransactionError(f"Failed to delete list: {err}")

async def check_list_permission(db: AsyncSession, list_id: int, user_id: int, require_creator: bool = False) -> ListModel:
    """Load a list (items included) and ensure ``user_id`` may access it.

    Access rules:
      * The creator always has access.
      * When ``require_creator`` is true, no one else has access.
      * Otherwise, a list with a ``group_id`` is accessible to members of that
        group. A list without a group is accessible only to its creator.

    Returns:
        The loaded list when access is allowed.

    Raises:
        ListNotFoundError: no list with ``list_id`` exists.
        ListCreatorRequiredError: ``require_creator`` is set and the user is
            not the creator.
        ListPermissionError: the user is neither the creator nor a member of
            the list's group.
        DatabaseConnectionError / DatabaseQueryError: SQLAlchemy failures
            raised while checking group membership.
    """
    try:
        target = await get_list_by_id(db, list_id=list_id, load_items=True)
        if not target:
            raise ListNotFoundError(list_id)

        if target.created_by_id == user_id:
            return target

        if require_creator:
            raise ListCreatorRequiredError(list_id, "access")

        if not target.group_id:
            raise ListPermissionError(list_id)

        # Function-level import, presumably to avoid a circular import with
        # app.crud.group — confirm before hoisting.
        from app.crud.group import is_user_member
        if await is_user_member(db, group_id=target.group_id, user_id=user_id):
            return target
        raise ListPermissionError(list_id)
    except OperationalError as err:
        raise DatabaseConnectionError(f"Failed to connect to database: {err}")
    except SQLAlchemyError as err:
        raise DatabaseQueryError(f"Failed to check list permissions: {err}")

async def get_list_status(db: AsyncSession, list_id: int) -> ListStatus:
    """Return the list's ``updated_at``, its newest item ``updated_at`` and its item count.

    Raises:
        ListNotFoundError: no list with ``list_id`` exists.
        DatabaseConnectionError: SQLAlchemy raised an ``OperationalError``.
        DatabaseQueryError: any other ``SQLAlchemyError``.
    """
    try:
        # Select the id alongside updated_at so "row missing" is distinguished
        # from "row present with a NULL updated_at".
        list_query = select(ListModel.id, ListModel.updated_at).where(ListModel.id == list_id)
        list_row = (await db.execute(list_query)).one_or_none()

        if list_row is None:
            raise ListNotFoundError(list_id)

        # An aggregate query without GROUP BY always returns exactly one row:
        # COUNT is 0 and MAX is NULL when the list has no items.
        item_status_query = (
            select(
                sql_func.max(ItemModel.updated_at).label("latest_item_updated_at"),
                sql_func.count(ItemModel.id).label("item_count"),
            )
            .where(ItemModel.list_id == list_id)
        )
        item_status = (await db.execute(item_status_query)).one()

        return ListStatus(
            list_updated_at=list_row.updated_at,
            latest_item_updated_at=item_status.latest_item_updated_at,
            item_count=item_status.item_count,
        )
    except OperationalError as e:
        raise DatabaseConnectionError(f"Failed to connect to database: {str(e)}") from e
    except SQLAlchemyError as e:
        raise DatabaseQueryError(f"Failed to get list status: {str(e)}") from e

async def get_list_by_name_and_group(
    db: AsyncSession,
    name: str,
    group_id: Optional[int],
    user_id: int  # requesting user: scopes personal-list lookup and drives the permission check
) -> Optional[ListModel]:
    """Find a list the user may access, by name within a group or among personal lists.

    Used for conflict resolution when creating lists.

    * ``group_id is None``: only the requesting user's own personal lists are
      searched. Other users' personal lists are never accessible, and
      different users may legitimately reuse the same name.
    * ``group_id`` given: the list in that group is returned only if the user
      created it or is a member of the group.

    If several accessible lists share the name, the one with the lowest id is
    considered. ``creator`` and ``group`` are eager-loaded.

    Returns:
        The matching accessible list, or ``None``.

    Raises:
        DatabaseConnectionError: SQLAlchemy raised an ``OperationalError``.
        DatabaseQueryError: any other ``SQLAlchemyError``.
    """
    try:
        base_query = select(ListModel).where(ListModel.name == name)

        if group_id is not None:
            base_query = base_query.where(ListModel.group_id == group_id)
        else:
            # Without the creator filter, two users with a personal list of the
            # same name made scalar_one_or_none() raise MultipleResultsFound
            # (surfacing as DatabaseQueryError).
            base_query = base_query.where(
                ListModel.group_id.is_(None),
                ListModel.created_by_id == user_id,
            )

        base_query = (
            base_query.options(
                selectinload(ListModel.creator),
                selectinload(ListModel.group),
            )
            .order_by(ListModel.id)
            .limit(1)
        )

        list_result = await db.execute(base_query)
        target_list = list_result.scalars().first()

        if not target_list:
            return None

        # Permission check
        if target_list.created_by_id == user_id:
            return target_list

        if target_list.group_id:
            # Function-level import, presumably to avoid a circular import with
            # app.crud.group — confirm before hoisting.
            from app.crud.group import is_user_member
            is_member_of_group = await is_user_member(db, group_id=target_list.group_id, user_id=user_id)
            if is_member_of_group:
                return target_list

        # Not the creator, and either not a group list or not a group member.
        return None

    except OperationalError as e:
        raise DatabaseConnectionError(f"Failed to connect to database: {str(e)}") from e
    except SQLAlchemyError as e:
        raise DatabaseQueryError(f"Failed to query list by name and group: {str(e)}") from e