188 lines
8.2 KiB
Python
188 lines
8.2 KiB
Python
from sqlalchemy.ext.asyncio import AsyncSession
|
|
from sqlalchemy.future import select
|
|
from sqlalchemy.orm import selectinload
|
|
from sqlalchemy import delete as sql_delete, update as sql_update
|
|
from sqlalchemy.exc import SQLAlchemyError, IntegrityError, OperationalError
|
|
from typing import Optional, List as PyList
|
|
from datetime import datetime, timezone
|
|
import logging
|
|
from sqlalchemy import func
|
|
|
|
from app.models import Item as ItemModel, User as UserModel
|
|
from app.schemas.item import ItemCreate, ItemUpdate
|
|
from app.core.exceptions import (
|
|
ItemNotFoundError,
|
|
DatabaseConnectionError,
|
|
DatabaseIntegrityError,
|
|
DatabaseQueryError,
|
|
DatabaseTransactionError,
|
|
ConflictError,
|
|
ItemOperationError
|
|
)
|
|
|
|
# Module-level logger (named after this module) used by the item CRUD error handlers.
logger = logging.getLogger(__name__)
|
|
|
|
async def create_item(db: AsyncSession, item_in: ItemCreate, list_id: int, user_id: int) -> ItemModel:
    """Create a new item in a list, placing it after the current last position.

    The work runs inside a SAVEPOINT (``begin_nested``) when the session already
    has a transaction open, otherwise inside a new transaction (``begin``).

    Args:
        db: Async session used for all statements.
        item_in: Validated payload; ``name`` and ``quantity`` are copied onto the item.
        list_id: ID of the list the item is added to.
        user_id: ID of the creating user, stored as ``added_by_id``.

    Returns:
        The new item, re-selected with ``added_by_user`` and ``completed_by_user``
        eagerly loaded.

    Raises:
        ItemOperationError: The item could not be re-loaded after the flush.
        DatabaseIntegrityError: The database reported an IntegrityError.
        DatabaseConnectionError: The database reported an OperationalError.
        DatabaseTransactionError: Any other SQLAlchemyError.
    """
    try:
        async with db.begin_nested() if db.in_transaction() else db.begin():
            # NOTE(review): MAX(position) + 1 is not concurrency-safe; two
            # simultaneous inserts into the same list can get the same position.
            max_pos_stmt = select(func.max(ItemModel.position)).where(ItemModel.list_id == list_id)
            max_pos_result = await db.execute(max_pos_stmt)
            # MAX() yields NULL for a list with no items, so fall back to 0.
            max_pos = max_pos_result.scalar_one_or_none() or 0

            db_item = ItemModel(
                name=item_in.name,
                quantity=item_in.quantity,
                list_id=list_id,
                added_by_id=user_id,
                is_complete=False,
                position=max_pos + 1,
            )
            db.add(db_item)
            # Flush so db_item.id is populated before it is used to re-select the row.
            await db.flush()

            stmt = (
                select(ItemModel)
                .where(ItemModel.id == db_item.id)
                .options(
                    selectinload(ItemModel.added_by_user),
                    selectinload(ItemModel.completed_by_user),
                )
            )
            result = await db.execute(stmt)
            loaded_item = result.scalar_one_or_none()

            if loaded_item is None:
                raise ItemOperationError("Failed to load item after creation.")

            return loaded_item
    except IntegrityError as e:
        logger.error("Database integrity error during item creation: %s", e, exc_info=True)
        raise DatabaseIntegrityError(f"Failed to create item: {str(e)}") from e
    except OperationalError as e:
        logger.error("Database connection error during item creation: %s", e, exc_info=True)
        raise DatabaseConnectionError(f"Database connection error: {str(e)}") from e
    except SQLAlchemyError as e:
        logger.error("Unexpected SQLAlchemy error during item creation: %s", e, exc_info=True)
        raise DatabaseTransactionError(f"Failed to create item: {str(e)}") from e
|
|
|
|
async def get_items_by_list_id(db: AsyncSession, list_id: int) -> PyList[ItemModel]:
    """Return all items belonging to a list, ordered by creation time (oldest first).

    ``added_by_user`` and ``completed_by_user`` are eagerly loaded on every item.

    Args:
        db: Async session used for the query.
        list_id: ID of the list whose items are fetched.

    Returns:
        A (possibly empty) list of items.

    Raises:
        DatabaseConnectionError: The database reported an OperationalError.
        DatabaseQueryError: Any other SQLAlchemyError.
    """
    try:
        stmt = (
            select(ItemModel)
            .where(ItemModel.list_id == list_id)
            .options(
                selectinload(ItemModel.added_by_user),
                selectinload(ItemModel.completed_by_user),
            )
            # NOTE(review): items carry a ``position`` that create/update maintain,
            # but this query orders by ``created_at``; confirm which order callers expect.
            .order_by(ItemModel.created_at.asc())
        )
        result = await db.execute(stmt)
        # Materialise a real list to match the annotated return type.
        return list(result.scalars().all())
    except OperationalError as e:
        logger.error("Database connection error while fetching items for list %s: %s", list_id, e, exc_info=True)
        raise DatabaseConnectionError(f"Failed to connect to database: {str(e)}") from e
    except SQLAlchemyError as e:
        logger.error("Unexpected SQLAlchemy error while fetching items for list %s: %s", list_id, e, exc_info=True)
        raise DatabaseQueryError(f"Failed to query items: {str(e)}") from e
|
|
|
|
async def get_item_by_id(db: AsyncSession, item_id: int) -> Optional[ItemModel]:
    """Return a single item by primary key, or ``None`` if it does not exist.

    ``added_by_user``, ``completed_by_user`` and ``list`` are eagerly loaded.

    Args:
        db: Async session used for the query.
        item_id: Primary key of the item.

    Raises:
        DatabaseConnectionError: The database reported an OperationalError.
        DatabaseQueryError: Any other SQLAlchemyError.
    """
    try:
        stmt = (
            select(ItemModel)
            .where(ItemModel.id == item_id)
            .options(
                selectinload(ItemModel.added_by_user),
                selectinload(ItemModel.completed_by_user),
                selectinload(ItemModel.list),
            )
        )
        result = await db.execute(stmt)
        return result.scalars().first()
    except OperationalError as e:
        logger.error("Database connection error while fetching item %s: %s", item_id, e, exc_info=True)
        raise DatabaseConnectionError(f"Failed to connect to database: {str(e)}") from e
    except SQLAlchemyError as e:
        logger.error("Unexpected SQLAlchemy error while fetching item %s: %s", item_id, e, exc_info=True)
        raise DatabaseQueryError(f"Failed to query item: {str(e)}") from e
|
|
|
|
async def _reorder_item_in_list(db: AsyncSession, item_db: ItemModel, new_position: int) -> None:
    """Move ``item_db`` to the 1-based ``new_position`` and renumber its whole list 1..N."""
    stmt = (
        select(ItemModel)
        .where(ItemModel.list_id == item_db.list_id)
        .order_by(ItemModel.position.asc(), ItemModel.created_at.asc())
    )
    result = await db.execute(stmt)
    # Copy into a plain mutable list; it is reordered in place below.
    items_in_list = list(result.scalars().all())

    item_to_move = next((it for it in items_in_list if it.id == item_db.id), None)
    if item_to_move is not None:
        items_in_list.remove(item_to_move)
        # Clamp out-of-range targets to the first / last slot.
        insert_pos = max(0, min(new_position - 1, len(items_in_list)))
        items_in_list.insert(insert_pos, item_to_move)

    # NOTE(review): neighbours whose position changes do not get a version bump;
    # only the moved item's version is incremented by the caller.
    for index, item in enumerate(items_in_list, start=1):
        item.position = index


def _apply_completion_owner(update_data: dict, item_db: ItemModel, user_id: int) -> None:
    """Set or clear ``completed_by_id`` in ``update_data`` to match its ``is_complete`` value."""
    if update_data['is_complete'] is True:
        # Keep the original completer if the item was already marked complete.
        if item_db.completed_by_id is None:
            update_data['completed_by_id'] = user_id
    else:
        update_data['completed_by_id'] = None


async def update_item(db: AsyncSession, item_db: ItemModel, item_in: ItemUpdate, user_id: int) -> ItemModel:
    """Update an existing item, checking its version and handling reordering.

    Steps:
        1. Reject the update if the client's version differs from the stored one.
        2. If ``position`` was supplied (and is not null), move the item and
           renumber every item in the list.
        3. If ``is_complete`` was supplied, set or clear ``completed_by_id``.
        4. Copy the remaining supplied fields onto the item and bump its version.

    The work runs in a SAVEPOINT when the session already has a transaction
    open, otherwise in a new transaction.

    Args:
        db: Async session used for all statements.
        item_db: The persistent item to update.
        item_in: Payload; only explicitly set fields are applied (``version`` is
            used for the conflict check only).
        user_id: ID of the acting user, recorded as completer when completing.

    Returns:
        The updated item re-selected with ``added_by_user``, ``completed_by_user``
        and ``list`` eagerly loaded.

    Raises:
        ConflictError: ``item_in.version`` does not match ``item_db.version``.
        ItemOperationError: The item could not be re-loaded after the flush.
        DatabaseIntegrityError: The database reported an IntegrityError.
        DatabaseConnectionError: The database reported an OperationalError.
        DatabaseTransactionError: Any other SQLAlchemyError.
    """
    try:
        async with db.begin_nested() if db.in_transaction() else db.begin():
            # NOTE(review): optimistic check against the already-loaded row; a
            # concurrent writer between load and flush is not detected here.
            if item_db.version != item_in.version:
                raise ConflictError(
                    f"Item '{item_db.name}' (ID: {item_db.id}) has been modified. "
                    f"Your version is {item_in.version}, current version is {item_db.version}. Please refresh."
                )

            update_data = item_in.model_dump(exclude_unset=True, exclude={'version'})

            if 'position' in update_data:
                new_position = update_data.pop('position')
                # An explicit null means "no move"; previously this raised
                # TypeError on ``None - 1``.
                if new_position is not None:
                    await _reorder_item_in_list(db, item_db, new_position)

            if 'is_complete' in update_data:
                _apply_completion_owner(update_data, item_db, user_id)

            for key, value in update_data.items():
                setattr(item_db, key, value)

            item_db.version += 1
            db.add(item_db)
            await db.flush()

            stmt = (
                select(ItemModel)
                .where(ItemModel.id == item_db.id)
                .options(
                    selectinload(ItemModel.added_by_user),
                    selectinload(ItemModel.completed_by_user),
                    selectinload(ItemModel.list),
                )
            )
            result = await db.execute(stmt)
            updated_item = result.scalar_one_or_none()

            if updated_item is None:
                raise ItemOperationError("Failed to load item after update.")

            return updated_item
    except IntegrityError as e:
        logger.error("Database integrity error during item update: %s", e, exc_info=True)
        raise DatabaseIntegrityError(f"Failed to update item due to integrity constraint: {str(e)}") from e
    except OperationalError as e:
        logger.error("Database connection error while updating item: %s", e, exc_info=True)
        raise DatabaseConnectionError(f"Database connection error while updating item: {str(e)}") from e
    except ConflictError:
        raise
    except SQLAlchemyError as e:
        logger.error("Unexpected SQLAlchemy error during item update: %s", e, exc_info=True)
        raise DatabaseTransactionError(f"Failed to update item: {str(e)}") from e
|
|
|
|
async def delete_item(db: AsyncSession, item_db: ItemModel) -> None:
    """Remove ``item_db`` from the database.

    No version (optimistic-lock) check is performed here; the caller is
    responsible for it. The delete runs inside a SAVEPOINT when the session
    already has a transaction open, otherwise inside a new transaction.

    Args:
        db: Async session used for the delete.
        item_db: The persistent item to delete.

    Raises:
        DatabaseConnectionError: The database reported an OperationalError.
        DatabaseTransactionError: Any other SQLAlchemyError.
    """
    try:
        already_in_tx = db.in_transaction()
        tx_scope = db.begin_nested() if already_in_tx else db.begin()
        async with tx_scope:
            await db.delete(item_db)
    except OperationalError as exc:
        logger.error(f"Database connection error while deleting item: {str(exc)}", exc_info=True)
        raise DatabaseConnectionError(f"Database connection error while deleting item: {str(exc)}")
    except SQLAlchemyError as exc:
        logger.error(f"Unexpected SQLAlchemy error while deleting item: {str(exc)}", exc_info=True)
        raise DatabaseTransactionError(f"Failed to delete item: {str(exc)}")