Refactor database session management across multiple API endpoints to utilize a transactional session, enhancing consistency in transaction handling. Update dependencies in costs, financials, groups, health, invites, items, and lists modules for improved error handling and reliability.

This commit is contained in:
mohamad 2025-05-20 01:19:06 +02:00
parent 98b2f907de
commit 323ce210ce
7 changed files with 51 additions and 58 deletions

View File

@ -7,7 +7,7 @@ from sqlalchemy.orm import Session, selectinload
from decimal import Decimal, ROUND_HALF_UP, ROUND_DOWN
from typing import List
from app.database import get_db
from app.database import get_transactional_session
from app.auth import current_active_user
from app.models import (
User as UserModel,
@ -120,7 +120,7 @@ def calculate_suggested_settlements(user_balances: List[UserBalanceDetail]) -> L
)
async def get_list_cost_summary(
list_id: int,
db: AsyncSession = Depends(get_db),
db: AsyncSession = Depends(get_transactional_session),
current_user: UserModel = Depends(current_active_user),
):
"""
@ -282,7 +282,7 @@ async def get_list_cost_summary(
)
async def get_group_balance_summary(
group_id: int,
db: AsyncSession = Depends(get_db),
db: AsyncSession = Depends(get_transactional_session),
current_user: UserModel = Depends(current_active_user),
):
"""

View File

@ -5,7 +5,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select
from typing import List as PyList, Optional, Sequence
from app.database import get_db
from app.database import get_transactional_session
from app.auth import current_active_user
from app.models import User as UserModel, Group as GroupModel, List as ListModel, UserGroup as UserGroupModel, UserRoleEnum
from app.schemas.expense import (
@ -46,7 +46,7 @@ async def check_list_access_for_financials(db: AsyncSession, list_id: int, user_
)
async def create_new_expense(
expense_in: ExpenseCreate,
db: AsyncSession = Depends(get_db),
db: AsyncSession = Depends(get_transactional_session),
current_user: UserModel = Depends(current_active_user),
):
logger.info(f"User {current_user.email} creating expense: {expense_in.description}")
@ -109,7 +109,7 @@ async def create_new_expense(
@router.get("/expenses/{expense_id}", response_model=ExpensePublic, summary="Get Expense by ID", tags=["Expenses"])
async def get_expense(
expense_id: int,
db: AsyncSession = Depends(get_db),
db: AsyncSession = Depends(get_transactional_session),
current_user: UserModel = Depends(current_active_user),
):
logger.info(f"User {current_user.email} requesting expense ID {expense_id}")
@ -130,7 +130,7 @@ async def list_list_expenses(
list_id: int,
skip: int = Query(0, ge=0),
limit: int = Query(100, ge=1, le=200),
db: AsyncSession = Depends(get_db),
db: AsyncSession = Depends(get_transactional_session),
current_user: UserModel = Depends(current_active_user),
):
logger.info(f"User {current_user.email} listing expenses for list ID {list_id}")
@ -143,7 +143,7 @@ async def list_group_expenses(
group_id: int,
skip: int = Query(0, ge=0),
limit: int = Query(100, ge=1, le=200),
db: AsyncSession = Depends(get_db),
db: AsyncSession = Depends(get_transactional_session),
current_user: UserModel = Depends(current_active_user),
):
logger.info(f"User {current_user.email} listing expenses for group ID {group_id}")
@ -155,7 +155,7 @@ async def list_group_expenses(
async def update_expense_details(
expense_id: int,
expense_in: ExpenseUpdate,
db: AsyncSession = Depends(get_db),
db: AsyncSession = Depends(get_transactional_session),
current_user: UserModel = Depends(current_active_user),
):
"""
@ -209,7 +209,7 @@ async def update_expense_details(
async def delete_expense_record(
expense_id: int,
expected_version: Optional[int] = Query(None, description="Expected version for optimistic locking"),
db: AsyncSession = Depends(get_db),
db: AsyncSession = Depends(get_transactional_session),
current_user: UserModel = Depends(current_active_user),
):
"""
@ -273,7 +273,7 @@ async def delete_expense_record(
)
async def create_new_settlement(
settlement_in: SettlementCreate,
db: AsyncSession = Depends(get_db),
db: AsyncSession = Depends(get_transactional_session),
current_user: UserModel = Depends(current_active_user),
):
logger.info(f"User {current_user.email} recording settlement in group {settlement_in.group_id}")
@ -299,7 +299,7 @@ async def create_new_settlement(
@router.get("/settlements/{settlement_id}", response_model=SettlementPublic, summary="Get Settlement by ID", tags=["Settlements"])
async def get_settlement(
settlement_id: int,
db: AsyncSession = Depends(get_db),
db: AsyncSession = Depends(get_transactional_session),
current_user: UserModel = Depends(current_active_user),
):
logger.info(f"User {current_user.email} requesting settlement ID {settlement_id}")
@ -321,7 +321,7 @@ async def list_group_settlements(
group_id: int,
skip: int = Query(0, ge=0),
limit: int = Query(100, ge=1, le=200),
db: AsyncSession = Depends(get_db),
db: AsyncSession = Depends(get_transactional_session),
current_user: UserModel = Depends(current_active_user),
):
logger.info(f"User {current_user.email} listing settlements for group ID {group_id}")
@ -333,7 +333,7 @@ async def list_group_settlements(
async def update_settlement_details(
settlement_id: int,
settlement_in: SettlementUpdate,
db: AsyncSession = Depends(get_db),
db: AsyncSession = Depends(get_transactional_session),
current_user: UserModel = Depends(current_active_user),
):
"""
@ -387,7 +387,7 @@ async def update_settlement_details(
async def delete_settlement_record(
settlement_id: int,
expected_version: Optional[int] = Query(None, description="Expected version for optimistic locking"),
db: AsyncSession = Depends(get_db),
db: AsyncSession = Depends(get_transactional_session),
current_user: UserModel = Depends(current_active_user),
):
"""

View File

@ -5,13 +5,13 @@ from typing import List
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy.ext.asyncio import AsyncSession
from app.database import get_db
from app.database import get_transactional_session
from app.auth import current_active_user
from app.models import User as UserModel, UserRoleEnum # Import model and enum
from app.schemas.group import GroupCreate, GroupPublic
from app.schemas.invite import InviteCodePublic
from app.schemas.message import Message # For simple responses
from app.schemas.list import ListPublic
from app.schemas.list import ListPublic, ListDetail
from app.crud import group as crud_group
from app.crud import invite as crud_invite
from app.crud import list as crud_list
@ -36,7 +36,7 @@ router = APIRouter()
)
async def create_group(
group_in: GroupCreate,
db: AsyncSession = Depends(get_db),
db: AsyncSession = Depends(get_transactional_session),
current_user: UserModel = Depends(current_active_user),
):
"""Creates a new group, adding the creator as the owner."""
@ -54,7 +54,7 @@ async def create_group(
tags=["Groups"]
)
async def read_user_groups(
db: AsyncSession = Depends(get_db),
db: AsyncSession = Depends(get_transactional_session),
current_user: UserModel = Depends(current_active_user),
):
"""Retrieves all groups the current user is a member of."""
@ -71,7 +71,7 @@ async def read_user_groups(
)
async def read_group(
group_id: int,
db: AsyncSession = Depends(get_db),
db: AsyncSession = Depends(get_transactional_session),
current_user: UserModel = Depends(current_active_user),
):
"""Retrieves details for a specific group, including members, if the user is part of it."""
@ -98,7 +98,7 @@ async def read_group(
)
async def create_group_invite(
group_id: int,
db: AsyncSession = Depends(get_db),
db: AsyncSession = Depends(get_transactional_session),
current_user: UserModel = Depends(current_active_user),
):
"""Generates a new invite code for the group. Requires owner/admin role (MVP: owner only)."""
@ -121,15 +121,7 @@ async def create_group_invite(
# This case should ideally be covered by exceptions from create_invite now
raise InviteCreationError(group_id)
try:
await db.commit() # Explicit commit before returning
logger.info(f"User {current_user.email} created and committed invite code for group {group_id}")
except Exception as e:
logger.error(f"Failed to commit transaction after creating invite for group {group_id}: {e}", exc_info=True)
await db.rollback() # Ensure rollback if explicit commit fails
# Re-raise to ensure a 500 error is returned
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to save invite: {str(e)}")
logger.info(f"User {current_user.email} created invite code for group {group_id}")
return invite
@router.get(
@ -140,7 +132,7 @@ async def create_group_invite(
)
async def get_group_active_invite(
group_id: int,
db: AsyncSession = Depends(get_db),
db: AsyncSession = Depends(get_transactional_session),
current_user: UserModel = Depends(current_active_user),
):
"""Retrieves the active invite code for the group. Requires group membership (owner/admin to be stricter later if needed)."""
@ -177,7 +169,7 @@ async def get_group_active_invite(
)
async def leave_group(
group_id: int,
db: AsyncSession = Depends(get_db),
db: AsyncSession = Depends(get_transactional_session),
current_user: UserModel = Depends(current_active_user),
):
"""Removes the current user from the specified group."""
@ -216,7 +208,7 @@ async def leave_group(
async def remove_group_member(
group_id: int,
user_id_to_remove: int,
db: AsyncSession = Depends(get_db),
db: AsyncSession = Depends(get_transactional_session),
current_user: UserModel = Depends(current_active_user),
):
"""Removes a specified user from the group. Requires current user to be owner."""
@ -249,13 +241,13 @@ async def remove_group_member(
@router.get(
"/{group_id}/lists",
response_model=List[ListPublic],
response_model=List[ListDetail],
summary="Get Group Lists",
tags=["Groups", "Lists"]
)
async def read_group_lists(
group_id: int,
db: AsyncSession = Depends(get_db),
db: AsyncSession = Depends(get_transactional_session),
current_user: UserModel = Depends(current_active_user),
):
"""Retrieves all lists belonging to a specific group, if the user is a member."""

View File

@ -4,7 +4,7 @@ from fastapi import APIRouter, Depends
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.sql import text
from app.database import get_async_session
from app.database import get_transactional_session
from app.schemas.health import HealthStatus
from app.core.exceptions import DatabaseConnectionError
@ -18,7 +18,7 @@ router = APIRouter()
description="Checks the operational status of the API and its connection to the database.",
tags=["Health"]
)
async def check_health(db: AsyncSession = Depends(get_async_session)):
async def check_health(db: AsyncSession = Depends(get_transactional_session)):
"""
Health check endpoint. Verifies API reachability and database connection.
"""

View File

@ -3,7 +3,7 @@ import logging
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy.ext.asyncio import AsyncSession
from app.database import get_db
from app.database import get_transactional_session
from app.auth import current_active_user
from app.models import User as UserModel, UserRoleEnum
from app.schemas.invite import InviteAccept
@ -30,7 +30,7 @@ router = APIRouter()
)
async def accept_invite(
invite_in: InviteAccept,
db: AsyncSession = Depends(get_db),
db: AsyncSession = Depends(get_transactional_session),
current_user: UserModel = Depends(current_active_user),
):
"""Accepts a group invite using the provided invite code."""

View File

@ -5,7 +5,7 @@ from typing import List as PyList, Optional
from fastapi import APIRouter, Depends, HTTPException, status, Response, Query
from sqlalchemy.ext.asyncio import AsyncSession
from app.database import get_db
from app.database import get_transactional_session
from app.auth import current_active_user
# --- Import Models Correctly ---
from app.models import User as UserModel
@ -23,7 +23,7 @@ router = APIRouter()
# Now ItemModel is defined before being used as a type hint
async def get_item_and_verify_access(
item_id: int,
db: AsyncSession = Depends(get_db),
db: AsyncSession = Depends(get_transactional_session),
current_user: UserModel = Depends(current_active_user)
) -> ItemModel:
"""Dependency to get an item and verify the user has access to its list."""
@ -52,7 +52,7 @@ async def get_item_and_verify_access(
async def create_list_item(
list_id: int,
item_in: ItemCreate,
db: AsyncSession = Depends(get_db),
db: AsyncSession = Depends(get_transactional_session),
current_user: UserModel = Depends(current_active_user),
):
"""Adds a new item to a specific list. User must have access to the list."""
@ -80,7 +80,7 @@ async def create_list_item(
)
async def read_list_items(
list_id: int,
db: AsyncSession = Depends(get_db),
db: AsyncSession = Depends(get_transactional_session),
current_user: UserModel = Depends(current_active_user),
# Add sorting/filtering params later if needed: sort_by: str = 'created_at', order: str = 'asc'
):
@ -99,7 +99,7 @@ async def read_list_items(
@router.put(
"/items/{item_id}", # Operate directly on item ID
"/lists/{list_id}/items/{item_id}", # Nested under lists
response_model=ItemPublic,
summary="Update Item",
tags=["Items"],
@ -108,10 +108,11 @@ async def read_list_items(
}
)
async def update_item(
item_id: int, # Item ID from path
list_id: int,
item_id: int,
item_in: ItemUpdate,
item_db: ItemModel = Depends(get_item_and_verify_access), # Use dependency to get item and check list access
db: AsyncSession = Depends(get_db),
db: AsyncSession = Depends(get_transactional_session),
current_user: UserModel = Depends(current_active_user), # Need user ID for completed_by
):
"""
@ -140,7 +141,7 @@ async def update_item(
@router.delete(
"/items/{item_id}", # Operate directly on item ID
"/lists/{list_id}/items/{item_id}", # Nested under lists
status_code=status.HTTP_204_NO_CONTENT,
summary="Delete Item",
tags=["Items"],
@ -149,10 +150,11 @@ async def update_item(
}
)
async def delete_item(
item_id: int, # Item ID from path
list_id: int,
item_id: int,
expected_version: Optional[int] = Query(None, description="The expected version of the item to delete for optimistic locking."),
item_db: ItemModel = Depends(get_item_and_verify_access), # Use dependency to get item and check list access
db: AsyncSession = Depends(get_db),
db: AsyncSession = Depends(get_transactional_session),
current_user: UserModel = Depends(current_active_user), # Log who deleted it
):
"""

View File

@ -5,7 +5,7 @@ from typing import List as PyList, Optional # Alias for Python List type hint
from fastapi import APIRouter, Depends, HTTPException, status, Response, Query # Added Query
from sqlalchemy.ext.asyncio import AsyncSession
from app.database import get_db
from app.database import get_transactional_session
from app.auth import current_active_user
from app.models import User as UserModel
from app.schemas.list import ListCreate, ListUpdate, ListPublic, ListDetail
@ -40,7 +40,7 @@ router = APIRouter()
)
async def create_list(
list_in: ListCreate,
db: AsyncSession = Depends(get_db),
db: AsyncSession = Depends(get_transactional_session),
current_user: UserModel = Depends(current_active_user),
):
"""
@ -61,7 +61,6 @@ async def create_list(
try:
created_list = await crud_list.create_list(db=db, list_in=list_in, creator_id=current_user.id)
await db.commit() # Ensure the transaction is committed
logger.info(f"List '{created_list.name}' (ID: {created_list.id}) created successfully for user {current_user.email}.")
return created_list
except DatabaseIntegrityError as e:
@ -87,12 +86,12 @@ async def create_list(
@router.get(
"", # Route relative to prefix "/lists"
response_model=PyList[ListPublic], # Return a list of basic list info
response_model=PyList[ListDetail], # Return a list of detailed list info including items
summary="List Accessible Lists",
tags=["Lists"]
)
async def read_lists(
db: AsyncSession = Depends(get_db),
db: AsyncSession = Depends(get_transactional_session),
current_user: UserModel = Depends(current_active_user),
# Add pagination parameters later if needed: skip: int = 0, limit: int = 100
):
@ -114,7 +113,7 @@ async def read_lists(
)
async def read_list(
list_id: int,
db: AsyncSession = Depends(get_db),
db: AsyncSession = Depends(get_transactional_session),
current_user: UserModel = Depends(current_active_user),
):
"""
@ -139,7 +138,7 @@ async def read_list(
async def update_list(
list_id: int,
list_in: ListUpdate,
db: AsyncSession = Depends(get_db),
db: AsyncSession = Depends(get_transactional_session),
current_user: UserModel = Depends(current_active_user),
):
"""
@ -177,7 +176,7 @@ async def update_list(
async def delete_list(
list_id: int,
expected_version: Optional[int] = Query(None, description="The expected version of the list to delete for optimistic locking."),
db: AsyncSession = Depends(get_db),
db: AsyncSession = Depends(get_transactional_session),
current_user: UserModel = Depends(current_active_user),
):
"""
@ -212,7 +211,7 @@ async def delete_list(
)
async def read_list_status(
list_id: int,
db: AsyncSession = Depends(get_db),
db: AsyncSession = Depends(get_transactional_session),
current_user: UserModel = Depends(current_active_user),
):
"""