diff --git a/be/app/api/api_router.py b/be/app/api/api_router.py index 23b7759..7c6e939 100644 --- a/be/app/api/api_router.py +++ b/be/app/api/api_router.py @@ -1,12 +1,5 @@ -# app/api/api_router.py from fastapi import APIRouter - -from app.api.v1.api import api_router_v1 # Import the v1 router +from app.api.v1.api import api_router_v1 api_router = APIRouter() - -# Include versioned routers here, adding the /api prefix -api_router.include_router(api_router_v1, prefix="/v1") # Mounts v1 endpoints under /api/v1/... - -# Add other API versions later -# e.g., api_router.include_router(api_router_v2, prefix="/v2") \ No newline at end of file +api_router.include_router(api_router_v1, prefix="/v1") diff --git a/be/app/api/auth/oauth.py b/be/app/api/auth/oauth.py index 6a1b624..a610429 100644 --- a/be/app/api/auth/oauth.py +++ b/be/app/api/auth/oauth.py @@ -19,30 +19,26 @@ async def google_callback(request: Request, db: AsyncSession = Depends(get_trans token_data = await oauth.google.authorize_access_token(request) user_info = await oauth.google.parse_id_token(request, token_data) - # Check if user exists existing_user = (await db.execute(select(User).where(User.email == user_info['email']))).scalar_one_or_none() user_to_login = existing_user if not existing_user: - # Create new user new_user = User( email=user_info['email'], name=user_info.get('name', user_info.get('email')), - is_verified=True, # Email is verified by Google + is_verified=True, is_active=True ) db.add(new_user) - await db.flush() # Use flush instead of commit since we're in a transaction + await db.flush() user_to_login = new_user - # Generate JWT tokens using the new backend access_strategy = get_jwt_strategy() refresh_strategy = get_refresh_jwt_strategy() access_token = await access_strategy.write_token(user_to_login) refresh_token = await refresh_strategy.write_token(user_to_login) - # Redirect to frontend with tokens redirect_url = f"{settings.FRONTEND_URL}/auth/callback?access_token={access_token}&refresh_token={refresh_token}" return RedirectResponse(url=redirect_url) @@ -62,12 +58,10 @@ async def apple_callback(request: Request, db: AsyncSession = Depends(get_transa if 'email' not in user_info: return RedirectResponse(url=f"{settings.FRONTEND_URL}/auth/callback?error=apple_email_missing") - # Check if user exists existing_user = (await db.execute(select(User).where(User.email == user_info['email']))).scalar_one_or_none() user_to_login = existing_user if not existing_user: - # Create new user name_info = user_info.get('name', {}) first_name = name_info.get('firstName', '') last_name = name_info.get('lastName', '') @@ -76,21 +70,19 @@ async def apple_callback(request: Request, db: AsyncSession = Depends(get_transa new_user = User( email=user_info['email'], name=full_name, - is_verified=True, # Email is verified by Apple + is_verified=True, is_active=True ) db.add(new_user) - await db.flush() # Use flush instead of commit since we're in a transaction + await db.flush() user_to_login = new_user - # Generate JWT tokens using the new backend access_strategy = get_jwt_strategy() refresh_strategy = get_refresh_jwt_strategy() access_token = await access_strategy.write_token(user_to_login) refresh_token = await refresh_strategy.write_token(user_to_login) - # Redirect to frontend with tokens redirect_url = f"{settings.FRONTEND_URL}/auth/callback?access_token={access_token}&refresh_token={refresh_token}" return RedirectResponse(url=redirect_url) @@ -113,7 +105,6 @@ async def refresh_jwt_token(request: Request): access_strategy = 
get_jwt_strategy() access_token = await access_strategy.write_token(user) - # Optionally, issue a new refresh token (rotation) new_refresh_token = await refresh_strategy.write_token(user) return JSONResponse({ "access_token": access_token, diff --git a/be/app/api/v1/api.py b/be/app/api/v1/api.py index 4d0599e..996f6ad 100644 --- a/be/app/api/v1/api.py +++ b/be/app/api/v1/api.py @@ -23,5 +23,3 @@ api_router_v1.include_router(costs.router, prefix="/costs", tags=["Costs"]) api_router_v1.include_router(financials.router, prefix="/financials", tags=["Financials"]) api_router_v1.include_router(chores.router, prefix="/chores", tags=["Chores"]) api_router_v1.include_router(oauth.router, prefix="/auth", tags=["Auth"]) -# Add other v1 endpoint routers here later -# e.g., api_router_v1.include_router(users.router, prefix="/users", tags=["Users"]) \ No newline at end of file diff --git a/be/app/api/v1/endpoints/chores.py b/be/app/api/v1/endpoints/chores.py index ba72bcd..3e0cd4a 100644 --- a/be/app/api/v1/endpoints/chores.py +++ b/be/app/api/v1/endpoints/chores.py @@ -20,7 +20,6 @@ from app.core.exceptions import ChoreNotFoundError, PermissionDeniedError, Group logger = logging.getLogger(__name__) router = APIRouter() -# Add this new endpoint before the personal chores section @router.get( "/all", response_model=PyList[ChorePublic], @@ -28,13 +27,12 @@ router = APIRouter() tags=["Chores"] ) async def list_all_chores( - db: AsyncSession = Depends(get_session), # Use read-only session for GET + db: AsyncSession = Depends(get_session), current_user: UserModel = Depends(current_active_user), ): """Retrieves all chores (personal and group) for the current user in a single optimized request.""" logger.info(f"User {current_user.email} listing all their chores") - # Use the optimized function that reduces database queries all_chores = await crud_chore.get_all_user_chores(db=db, user_id=current_user.id) return all_chores @@ -135,14 +133,12 @@ async def delete_personal_chore( """Deletes a personal chore for the current user.""" logger.info(f"User {current_user.email} deleting personal chore ID: {chore_id}") try: - # First, verify it's a personal chore belonging to the user chore_to_delete = await crud_chore.get_chore_by_id(db, chore_id) if not chore_to_delete or chore_to_delete.type != ChoreTypeEnum.personal or chore_to_delete.created_by_id != current_user.id: raise ChoreNotFoundError(chore_id=chore_id, detail="Personal chore not found or not owned by user.") success = await crud_chore.delete_chore(db=db, chore_id=chore_id, user_id=current_user.id, group_id=None) if not success: - # This case should be rare if the above check passes and DB is consistent raise ChoreNotFoundError(chore_id=chore_id) return Response(status_code=status.HTTP_204_NO_CONTENT) except ChoreNotFoundError as e: @@ -156,7 +152,6 @@ async def delete_personal_chore( raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=e.detail) # --- Group Chores Endpoints --- -# (These would be similar to what you might have had before, but now explicitly part of this router) @router.post( "/groups/{group_id}/chores", @@ -235,7 +230,6 @@ async def update_group_chore( if chore_in.group_id is not None and chore_in.group_id != group_id: raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=f"Chore's group_id if provided must match path group_id ({group_id}).") - # Ensure chore_in has the correct type for the CRUD operation chore_payload = chore_in.model_copy(update={"type": ChoreTypeEnum.group, "group_id": group_id} if 
chore_in.type is None else {"group_id": group_id}) try: @@ -271,15 +265,12 @@ async def delete_group_chore( """Deletes a chore from a group, ensuring user has permission.""" logger.info(f"User {current_user.email} deleting chore ID {chore_id} from group {group_id}") try: - # Verify chore exists and belongs to the group before attempting deletion via CRUD - # This gives a more precise error if the chore exists but isn't in this group. chore_to_delete = await crud_chore.get_chore_by_id_and_group(db, chore_id, group_id, current_user.id) # checks permission too if not chore_to_delete : # get_chore_by_id_and_group will raise PermissionDeniedError if user not member raise ChoreNotFoundError(chore_id=chore_id, group_id=group_id) success = await crud_chore.delete_chore(db=db, chore_id=chore_id, user_id=current_user.id, group_id=group_id) if not success: - # This case should be rare if the above check passes and DB is consistent raise ChoreNotFoundError(chore_id=chore_id, group_id=group_id) return Response(status_code=status.HTTP_204_NO_CONTENT) except ChoreNotFoundError as e: @@ -331,7 +322,7 @@ async def create_chore_assignment( ) async def list_my_assignments( include_completed: bool = False, - db: AsyncSession = Depends(get_session), # Use read-only session for GET + db: AsyncSession = Depends(get_session), current_user: UserModel = Depends(current_active_user), ): """Retrieves all chore assignments for the current user.""" @@ -350,7 +341,7 @@ async def list_my_assignments( ) async def list_chore_assignments( chore_id: int, - db: AsyncSession = Depends(get_session), # Use read-only session for GET + db: AsyncSession = Depends(get_session), current_user: UserModel = Depends(current_active_user), ): """Retrieves all assignments for a specific chore.""" @@ -471,7 +462,6 @@ async def get_chore_history( current_user: UserModel = Depends(current_active_user), ): """Retrieves the history of a specific chore.""" - # First, check if user has permission to view the chore itself chore = await crud_chore.get_chore_by_id(db, chore_id) if not chore: raise ChoreNotFoundError(chore_id=chore_id) @@ -503,10 +493,9 @@ async def get_chore_assignment_history( if not assignment: raise ChoreNotFoundError(assignment_id=assignment_id) - # Check permission by checking permission on the parent chore chore = await crud_chore.get_chore_by_id(db, assignment.chore_id) if not chore: - raise ChoreNotFoundError(chore_id=assignment.chore_id) # Should not happen if assignment exists + raise ChoreNotFoundError(chore_id=assignment.chore_id) if chore.type == ChoreTypeEnum.personal and chore.created_by_id != current_user.id: raise PermissionDeniedError("You can only view history for assignments of your own personal chores.") diff --git a/be/app/api/v1/endpoints/costs.py b/be/app/api/v1/endpoints/costs.py index 7124715..f7adb00 100644 --- a/be/app/api/v1/endpoints/costs.py +++ b/be/app/api/v1/endpoints/costs.py @@ -1,4 +1,4 @@ -# app/api/v1/endpoints/costs.py + import logging from fastapi import APIRouter, Depends, HTTPException, status from sqlalchemy import select @@ -18,14 +18,14 @@ from app.models import ( UserGroup as UserGroupModel, SplitTypeEnum, ExpenseSplit as ExpenseSplitModel, - Settlement as SettlementModel, - SettlementActivity as SettlementActivityModel # Added + SettlementActivity as SettlementActivityModel, + Settlement as SettlementModel ) from app.schemas.cost import ListCostSummary, GroupBalanceSummary, UserCostShare, UserBalanceDetail, SuggestedSettlement from app.schemas.expense import ExpenseCreate from app.crud 
import list as crud_list from app.crud import expense as crud_expense -from app.core.exceptions import ListNotFoundError, ListPermissionError, UserNotFoundError, GroupNotFoundError +from app.core.exceptions import ListNotFoundError, ListPermissionError, GroupNotFoundError logger = logging.getLogger(__name__) router = APIRouter() diff --git a/be/app/api/v1/endpoints/groups.py b/be/app/api/v1/endpoints/groups.py index 6e2acbc..17eeab0 100644 --- a/be/app/api/v1/endpoints/groups.py +++ b/be/app/api/v1/endpoints/groups.py @@ -1,4 +1,3 @@ -# app/api/v1/endpoints/groups.py import logging from typing import List @@ -7,11 +6,11 @@ from sqlalchemy.ext.asyncio import AsyncSession from app.database import get_transactional_session, get_session from app.auth import current_active_user -from app.models import User as UserModel, UserRoleEnum # Import model and enum +from app.models import User as UserModel, UserRoleEnum from app.schemas.group import GroupCreate, GroupPublic, GroupScheduleGenerateRequest from app.schemas.invite import InviteCodePublic -from app.schemas.message import Message # For simple responses -from app.schemas.list import ListPublic, ListDetail +from app.schemas.message import Message +from app.schemas.list import ListDetail from app.schemas.chore import ChoreHistoryPublic, ChoreAssignmentPublic from app.schemas.user import UserPublic from app.crud import group as crud_group @@ -46,8 +45,6 @@ async def create_group( """Creates a new group, adding the creator as the owner.""" logger.info(f"User {current_user.email} creating group: {group_in.name}") created_group = await crud_group.create_group(db=db, group_in=group_in, creator_id=current_user.id) - # Load members explicitly if needed for the response (optional here) - # created_group = await crud_group.get_group_by_id(db, created_group.id) return created_group @@ -58,7 +55,7 @@ async def create_group( tags=["Groups"] ) async def read_user_groups( - db: AsyncSession = Depends(get_session), # Use read-only session for GET + db: AsyncSession = Depends(get_session), current_user: UserModel = Depends(current_active_user), ): """Retrieves all groups the current user is a member of.""" @@ -75,12 +72,11 @@ async def read_user_groups( ) async def read_group( group_id: int, - db: AsyncSession = Depends(get_session), # Use read-only session for GET + db: AsyncSession = Depends(get_session), current_user: UserModel = Depends(current_active_user), ): """Retrieves details for a specific group, including members, if the user is part of it.""" logger.info(f"User {current_user.email} requesting details for group ID: {group_id}") - # Check if user is a member first is_member = await crud_group.is_user_member(db=db, group_id=group_id, user_id=current_user.id) if not is_member: logger.warning(f"Access denied: User {current_user.email} not member of group {group_id}") @@ -101,13 +97,12 @@ async def read_group( ) async def read_group_members( group_id: int, - db: AsyncSession = Depends(get_session), # Use read-only session for GET + db: AsyncSession = Depends(get_session), current_user: UserModel = Depends(current_active_user), ): """Retrieves all members of a specific group, if the user is part of it.""" logger.info(f"User {current_user.email} requesting members for group ID: {group_id}") - # Check if user is a member first is_member = await crud_group.is_user_member(db=db, group_id=group_id, user_id=current_user.id) if not is_member: logger.warning(f"Access denied: User {current_user.email} not member of group {group_id}") @@ -118,7 +113,6 @@ async def 
read_group_members( logger.error(f"Group {group_id} requested by member {current_user.email} not found (data inconsistency?)") raise GroupNotFoundError(group_id) - # Extract and return just the user information from member associations return [member_assoc.user for member_assoc in group.member_associations] @router.post( @@ -136,12 +130,10 @@ async def create_group_invite( logger.info(f"User {current_user.email} attempting to create invite for group {group_id}") user_role = await crud_group.get_user_role_in_group(db, group_id=group_id, user_id=current_user.id) - # --- Permission Check (MVP: Owner only) --- if user_role != UserRoleEnum.owner: logger.warning(f"Permission denied: User {current_user.email} (role: {user_role}) cannot create invite for group {group_id}") raise GroupPermissionError(group_id, "create invites") - # Check if group exists (implicitly done by role check, but good practice) group = await crud_group.get_group_by_id(db, group_id) if not group: raise GroupNotFoundError(group_id) @@ -149,7 +141,6 @@ async def create_group_invite( invite = await crud_invite.create_invite(db=db, group_id=group_id, creator_id=current_user.id) if not invite: logger.error(f"Failed to generate unique invite code for group {group_id}") - # This case should ideally be covered by exceptions from create_invite now raise InviteCreationError(group_id) logger.info(f"User {current_user.email} created invite code for group {group_id}") @@ -163,26 +154,20 @@ async def create_group_invite( ) async def get_group_active_invite( group_id: int, - db: AsyncSession = Depends(get_session), # Use read-only session for GET + db: AsyncSession = Depends(get_session), current_user: UserModel = Depends(current_active_user), ): """Retrieves the active invite code for the group. Requires group membership (owner/admin to be stricter later if needed).""" logger.info(f"User {current_user.email} attempting to get active invite for group {group_id}") - # Permission check: Ensure user is a member of the group to view invite code - # Using get_user_role_in_group which also checks membership indirectly user_role = await crud_group.get_user_role_in_group(db, group_id=group_id, user_id=current_user.id) if user_role is None: # Not a member logger.warning(f"Permission denied: User {current_user.email} is not a member of group {group_id} and cannot view invite code.") - # More specific error or let GroupPermissionError handle if we want to be generic raise GroupMembershipError(group_id, "view invite code for this group (not a member)") - # Fetch the active invite for the group invite = await crud_invite.get_active_invite_for_group(db, group_id=group_id) if not invite: - # This case means no active (non-expired, active=true) invite exists. - # The frontend can then prompt to generate one. 
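For context, `crud_invite.get_active_invite_for_group` is defined outside this patch; it is expected to return only invites that are still active and unexpired, and a miss is what triggers the 404 below so the frontend can prompt the user to generate a new code. A minimal sketch of such a lookup — assuming an `Invite` model with `is_active`, `expires_at`, and `created_at` columns, which are not shown in this diff — might look like:

```python
# Hypothetical sketch only: crud_invite is not part of this patch, so the model
# name and the is_active / expires_at / created_at columns are assumptions.
from datetime import datetime, timezone
from typing import Optional

from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

from app.models import Invite as InviteModel  # assumed model name


async def get_active_invite_for_group(db: AsyncSession, group_id: int) -> Optional[InviteModel]:
    """Return the newest non-expired, active invite for a group, or None."""
    now = datetime.now(timezone.utc)
    stmt = (
        select(InviteModel)
        .where(
            InviteModel.group_id == group_id,
            InviteModel.is_active.is_(True),
            InviteModel.expires_at > now,
        )
        .order_by(InviteModel.created_at.desc())
        .limit(1)
    )
    return (await db.execute(stmt)).scalar_one_or_none()
```

Returning `None` here (rather than raising) keeps the "no active invite" case a normal outcome that the endpoint can translate into a 404.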
logger.info(f"No active invite code found for group {group_id} when requested by {current_user.email}") raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, @@ -190,7 +175,7 @@ async def get_group_active_invite( ) logger.info(f"User {current_user.email} retrieved active invite code for group {group_id}") - return invite # Pydantic will convert InviteModel to InviteCodePublic + return invite @router.delete( "/{group_id}/leave", @@ -210,27 +195,22 @@ async def leave_group( if user_role is None: raise GroupMembershipError(group_id, "leave (you are not a member)") - # Check if owner is the last member if user_role == UserRoleEnum.owner: member_count = await crud_group.get_group_member_count(db, group_id) if member_count <= 1: - # Delete the group since owner is the last member logger.info(f"Owner {current_user.email} is the last member. Deleting group {group_id}") await crud_group.delete_group(db, group_id) return Message(detail="Group deleted as you were the last member") - # Proceed with removal for non-owner or if there are other members deleted = await crud_group.remove_user_from_group(db, group_id=group_id, user_id=current_user.id) if not deleted: - # Should not happen if role check passed, but handle defensively logger.error(f"Failed to remove user {current_user.email} from group {group_id} despite being a member.") raise GroupOperationError("Failed to leave group") logger.info(f"User {current_user.email} successfully left group {group_id}") return Message(detail="Successfully left the group") -# --- Optional: Remove Member Endpoint --- @router.delete( "/{group_id}/members/{user_id_to_remove}", response_model=Message, @@ -247,21 +227,17 @@ async def remove_group_member( logger.info(f"Owner {current_user.email} attempting to remove user {user_id_to_remove} from group {group_id}") owner_role = await crud_group.get_user_role_in_group(db, group_id=group_id, user_id=current_user.id) - # --- Permission Check --- if owner_role != UserRoleEnum.owner: logger.warning(f"Permission denied: User {current_user.email} (role: {owner_role}) cannot remove members from group {group_id}") raise GroupPermissionError(group_id, "remove members") - # Prevent owner removing themselves via this endpoint if current_user.id == user_id_to_remove: raise GroupValidationError("Owner cannot remove themselves using this endpoint. 
Use 'Leave Group' instead.") - # Check if target user is actually in the group target_role = await crud_group.get_user_role_in_group(db, group_id=group_id, user_id=user_id_to_remove) if target_role is None: raise GroupMembershipError(group_id, "remove this user (they are not a member)") - # Proceed with removal deleted = await crud_group.remove_user_from_group(db, group_id=group_id, user_id=user_id_to_remove) if not deleted: @@ -279,19 +255,17 @@ async def remove_group_member( ) async def read_group_lists( group_id: int, - db: AsyncSession = Depends(get_session), # Use read-only session for GET + db: AsyncSession = Depends(get_session), current_user: UserModel = Depends(current_active_user), ): """Retrieves all lists belonging to a specific group, if the user is a member.""" logger.info(f"User {current_user.email} requesting lists for group ID: {group_id}") - # Check if user is a member first is_member = await crud_group.is_user_member(db=db, group_id=group_id, user_id=current_user.id) if not is_member: logger.warning(f"Access denied: User {current_user.email} not member of group {group_id}") raise GroupMembershipError(group_id, "view group lists") - # Get all lists for the user and filter by group_id lists = await crud_list.get_lists_for_user(db=db, user_id=current_user.id) group_lists = [list for list in lists if list.group_id == group_id] @@ -311,7 +285,6 @@ async def generate_group_chore_schedule( ): """Generates a round-robin chore schedule for a group.""" logger.info(f"User {current_user.email} generating chore schedule for group {group_id}") - # Permission check: ensure user is a member (or owner/admin if stricter rules are needed) if not await crud_group.is_user_member(db, group_id, current_user.id): raise GroupMembershipError(group_id, "generate chore schedule for this group") @@ -342,7 +315,6 @@ async def get_group_chore_history( ): """Retrieves all chore-related history for a specific group.""" logger.info(f"User {current_user.email} requesting chore history for group {group_id}") - # Permission check if not await crud_group.is_user_member(db, group_id, current_user.id): raise GroupMembershipError(group_id, "view chore history for this group") diff --git a/be/app/api/v1/endpoints/health.py b/be/app/api/v1/endpoints/health.py index fb2b167..75c6e4b 100644 --- a/be/app/api/v1/endpoints/health.py +++ b/be/app/api/v1/endpoints/health.py @@ -1,4 +1,3 @@ -# app/api/v1/endpoints/health.py import logging from fastapi import APIRouter, Depends from sqlalchemy.ext.asyncio import AsyncSession @@ -7,7 +6,6 @@ from sqlalchemy.sql import text from app.database import get_transactional_session from app.schemas.health import HealthStatus from app.core.exceptions import DatabaseConnectionError - logger = logging.getLogger(__name__) router = APIRouter() @@ -22,17 +20,9 @@ async def check_health(db: AsyncSession = Depends(get_transactional_session)): """ Health check endpoint. Verifies API reachability and database connection. 
""" - try: - # Try executing a simple query to check DB connection - result = await db.execute(text("SELECT 1")) - if result.scalar_one() == 1: - logger.info("Health check successful: Database connection verified.") - return HealthStatus(status="ok", database="connected") - else: - # This case should ideally not happen with 'SELECT 1' - logger.error("Health check failed: Database connection check returned unexpected result.") - raise DatabaseConnectionError("Unexpected result from database connection check") - - except Exception as e: - logger.error(f"Health check failed: Database connection error - {e}", exc_info=True) - raise DatabaseConnectionError(str(e)) \ No newline at end of file + result = await db.execute(text("SELECT 1")) + if result.scalar_one() == 1: + logger.info("Health check successful: Database connection verified.") + return HealthStatus(status="ok", database="connected") + logger.error("Health check failed: Database connection check returned unexpected result.") + raise DatabaseConnectionError("Unexpected result from database connection check") \ No newline at end of file diff --git a/be/app/api/v1/endpoints/invites.py b/be/app/api/v1/endpoints/invites.py index b1be222..2573125 100644 --- a/be/app/api/v1/endpoints/invites.py +++ b/be/app/api/v1/endpoints/invites.py @@ -1,21 +1,16 @@ -# app/api/v1/endpoints/invites.py import logging -from fastapi import APIRouter, Depends, HTTPException, status +from fastapi import APIRouter, Depends from sqlalchemy.ext.asyncio import AsyncSession from app.database import get_transactional_session from app.auth import current_active_user -from app.models import User as UserModel, UserRoleEnum +from app.models import User as UserModel from app.schemas.invite import InviteAccept -from app.schemas.message import Message from app.schemas.group import GroupPublic from app.crud import invite as crud_invite from app.crud import group as crud_group from app.core.exceptions import ( InviteNotFoundError, - InviteExpiredError, - InviteAlreadyUsedError, - InviteCreationError, GroupNotFoundError, GroupMembershipError, GroupOperationError @@ -25,7 +20,7 @@ logger = logging.getLogger(__name__) router = APIRouter() @router.post( - "/accept", # Route relative to prefix "/invites" + "/accept", response_model=GroupPublic, summary="Accept Group Invite", tags=["Invites"] @@ -37,42 +32,33 @@ async def accept_invite( ): """Accepts a group invite using the provided invite code.""" logger.info(f"User {current_user.email} attempting to accept invite code: {invite_in.code}") - - # Get the invite - this function should only return valid, active invites + invite = await crud_invite.get_active_invite_by_code(db, code=invite_in.code) if not invite: logger.warning(f"Invalid or inactive invite code attempted by user {current_user.email}: {invite_in.code}") - # We can use a more generic error or a specific one. InviteNotFound is reasonable. 
raise InviteNotFoundError(invite_in.code) - # Check if group still exists group = await crud_group.get_group_by_id(db, group_id=invite.group_id) if not group: logger.error(f"Group {invite.group_id} not found for invite {invite_in.code}") raise GroupNotFoundError(invite.group_id) - # Check if user is already a member is_member = await crud_group.is_user_member(db, group_id=invite.group_id, user_id=current_user.id) if is_member: logger.warning(f"User {current_user.email} already a member of group {invite.group_id}") raise GroupMembershipError(invite.group_id, "join (already a member)") - # Add user to the group added_to_group = await crud_group.add_user_to_group(db, group_id=invite.group_id, user_id=current_user.id) if not added_to_group: logger.error(f"Failed to add user {current_user.email} to group {invite.group_id} during invite acceptance.") - # This could be a race condition or other issue, treat as an operational error. raise GroupOperationError("Failed to add user to group.") - # Deactivate the invite so it cannot be used again await crud_invite.deactivate_invite(db, invite=invite) logger.info(f"User {current_user.email} successfully joined group {invite.group_id} via invite {invite_in.code}") - # Re-fetch the group to get the updated member list updated_group = await crud_group.get_group_by_id(db, group_id=invite.group_id) if not updated_group: - # This should ideally not happen as we found it before logger.error(f"Could not re-fetch group {invite.group_id} after user {current_user.email} joined.") raise GroupNotFoundError(invite.group_id) diff --git a/be/app/api/v1/endpoints/items.py b/be/app/api/v1/endpoints/items.py index 854c161..70aa4fa 100644 --- a/be/app/api/v1/endpoints/items.py +++ b/be/app/api/v1/endpoints/items.py @@ -1,4 +1,4 @@ -# app/api/v1/endpoints/items.py + import logging from typing import List as PyList, Optional @@ -6,21 +6,17 @@ from fastapi import APIRouter, Depends, HTTPException, status, Response, Query from sqlalchemy.ext.asyncio import AsyncSession from app.database import get_transactional_session -from app.auth import current_active_user -# --- Import Models Correctly --- from app.models import User as UserModel -from app.models import Item as ItemModel # <-- IMPORT Item and alias it -# --- End Import Models --- +from app.models import Item as ItemModel from app.schemas.item import ItemCreate, ItemUpdate, ItemPublic from app.crud import item as crud_item from app.crud import list as crud_list from app.core.exceptions import ItemNotFoundError, ListPermissionError, ConflictError +from app.auth import current_active_user logger = logging.getLogger(__name__) router = APIRouter() -# --- Helper Dependency for Item Permissions --- -# Now ItemModel is defined before being used as a type hint async def get_item_and_verify_access( item_id: int, db: AsyncSession = Depends(get_transactional_session), @@ -31,19 +27,15 @@ async def get_item_and_verify_access( if not item_db: raise ItemNotFoundError(item_id) - # Check permission on the parent list try: await crud_list.check_list_permission(db=db, list_id=item_db.list_id, user_id=current_user.id) except ListPermissionError as e: - # Re-raise with a more specific message raise ListPermissionError(item_db.list_id, "access this item's list") return item_db -# --- Endpoints --- - @router.post( - "/lists/{list_id}/items", # Nested under lists + "/lists/{list_id}/items", response_model=ItemPublic, status_code=status.HTTP_201_CREATED, summary="Add Item to List", @@ -56,13 +48,11 @@ async def create_list_item( current_user: 
UserModel = Depends(current_active_user), ): """Adds a new item to a specific list. User must have access to the list.""" - user_email = current_user.email # Access email attribute before async operations + user_email = current_user.email logger.info(f"User {user_email} adding item to list {list_id}: {item_in.name}") - # Verify user has access to the target list try: await crud_list.check_list_permission(db=db, list_id=list_id, user_id=current_user.id) except ListPermissionError as e: - # Re-raise with a more specific message raise ListPermissionError(list_id, "add items to this list") created_item = await crud_item.create_item( @@ -73,7 +63,7 @@ async def create_list_item( @router.get( - "/lists/{list_id}/items", # Nested under lists + "/lists/{list_id}/items", response_model=PyList[ItemPublic], summary="List Items in List", tags=["Items"] @@ -82,16 +72,13 @@ async def read_list_items( list_id: int, db: AsyncSession = Depends(get_transactional_session), current_user: UserModel = Depends(current_active_user), - # Add sorting/filtering params later if needed: sort_by: str = 'created_at', order: str = 'asc' ): """Retrieves all items for a specific list if the user has access.""" - user_email = current_user.email # Access email attribute before async operations + user_email = current_user.email logger.info(f"User {user_email} listing items for list {list_id}") - # Verify user has access to the list try: await crud_list.check_list_permission(db=db, list_id=list_id, user_id=current_user.id) - except ListPermissionError as e: - # Re-raise with a more specific message + except ListPermissionError as e: raise ListPermissionError(list_id, "view items in this list") items = await crud_item.get_items_by_list_id(db=db, list_id=list_id) @@ -99,7 +86,7 @@ async def read_list_items( @router.put( - "/lists/{list_id}/items/{item_id}", # Nested under lists + "/lists/{list_id}/items/{item_id}", response_model=ItemPublic, summary="Update Item", tags=["Items"], @@ -111,9 +98,9 @@ async def update_item( list_id: int, item_id: int, item_in: ItemUpdate, - item_db: ItemModel = Depends(get_item_and_verify_access), # Use dependency to get item and check list access + item_db: ItemModel = Depends(get_item_and_verify_access), db: AsyncSession = Depends(get_transactional_session), - current_user: UserModel = Depends(current_active_user), # Need user ID for completed_by + current_user: UserModel = Depends(current_active_user), ): """ Updates an item's details (name, quantity, is_complete, price). @@ -122,9 +109,8 @@ async def update_item( If the version does not match, a 409 Conflict is returned. Sets/unsets `completed_by_id` based on `is_complete` flag. 
""" - user_email = current_user.email # Access email attribute before async operations + user_email = current_user.email logger.info(f"User {user_email} attempting to update item ID: {item_id} with version {item_in.version}") - # Permission check is handled by get_item_and_verify_access dependency try: updated_item = await crud_item.update_item( @@ -141,7 +127,7 @@ async def update_item( @router.delete( - "/lists/{list_id}/items/{item_id}", # Nested under lists + "/lists/{list_id}/items/{item_id}", status_code=status.HTTP_204_NO_CONTENT, summary="Delete Item", tags=["Items"], @@ -153,18 +139,16 @@ async def delete_item( list_id: int, item_id: int, expected_version: Optional[int] = Query(None, description="The expected version of the item to delete for optimistic locking."), - item_db: ItemModel = Depends(get_item_and_verify_access), # Use dependency to get item and check list access + item_db: ItemModel = Depends(get_item_and_verify_access), db: AsyncSession = Depends(get_transactional_session), - current_user: UserModel = Depends(current_active_user), # Log who deleted it + current_user: UserModel = Depends(current_active_user), ): """ Deletes an item. User must have access to the list the item belongs to. If `expected_version` is provided and does not match the item's current version, a 409 Conflict is returned. """ - user_email = current_user.email # Access email attribute before async operations - logger.info(f"User {user_email} attempting to delete item ID: {item_id}, expected version: {expected_version}") - # Permission check is handled by get_item_and_verify_access dependency + user_email = current_user.email if expected_version is not None and item_db.version != expected_version: logger.warning( diff --git a/be/app/api/v1/endpoints/lists.py b/be/app/api/v1/endpoints/lists.py index e26bd14..b8129e9 100644 --- a/be/app/api/v1/endpoints/lists.py +++ b/be/app/api/v1/endpoints/lists.py @@ -1,34 +1,27 @@ -# app/api/v1/endpoints/lists.py import logging -from typing import List as PyList, Optional # Alias for Python List type hint - -from fastapi import APIRouter, Depends, HTTPException, status, Response, Query # Added Query +from typing import List as PyList, Optional +from fastapi import APIRouter, Depends, HTTPException, status, Response, Query from sqlalchemy.ext.asyncio import AsyncSession - from app.database import get_transactional_session from app.auth import current_active_user from app.models import User as UserModel from app.schemas.list import ListCreate, ListUpdate, ListPublic, ListDetail -from app.schemas.message import Message # For simple responses from app.crud import list as crud_list -from app.crud import group as crud_group # Need for group membership check +from app.crud import group as crud_group from app.schemas.list import ListStatus, ListStatusWithId -from app.schemas.expense import ExpensePublic # Import ExpensePublic +from app.schemas.expense import ExpensePublic from app.core.exceptions import ( GroupMembershipError, - ListNotFoundError, - ListPermissionError, - ListStatusNotFoundError, - ConflictError, # Added ConflictError - DatabaseIntegrityError # Added DatabaseIntegrityError + ConflictError, + DatabaseIntegrityError ) logger = logging.getLogger(__name__) router = APIRouter() @router.post( - "", # Route relative to prefix "/lists" - response_model=ListPublic, # Return basic list info on creation + "", + response_model=ListPublic, status_code=status.HTTP_201_CREATED, summary="Create New List", tags=["Lists"], @@ -53,7 +46,6 @@ async def create_list( 
logger.info(f"User {current_user.email} creating list: {list_in.name}") group_id = list_in.group_id - # Permission Check: If sharing with a group, verify membership if group_id: is_member = await crud_group.is_user_member(db, group_id=group_id, user_id=current_user.id) if not is_member: @@ -65,9 +57,7 @@ async def create_list( logger.info(f"List '{created_list.name}' (ID: {created_list.id}) created successfully for user {current_user.email}.") return created_list except DatabaseIntegrityError as e: - # Check if this is a unique constraint violation if "unique constraint" in str(e).lower(): - # Find the existing list with the same name in the group existing_list = await crud_list.get_list_by_name_and_group( db=db, name=list_in.name, @@ -81,20 +71,18 @@ async def create_list( detail=f"A list named '{list_in.name}' already exists in this group.", headers={"X-Existing-List": str(existing_list.id)} ) - # If it's not a unique constraint or we couldn't find the existing list, re-raise raise @router.get( - "", # Route relative to prefix "/lists" - response_model=PyList[ListDetail], # Return a list of detailed list info including items + "", + response_model=PyList[ListDetail], summary="List Accessible Lists", tags=["Lists"] ) async def read_lists( db: AsyncSession = Depends(get_transactional_session), current_user: UserModel = Depends(current_active_user), - # Add pagination parameters later if needed: skip: int = 0, limit: int = 100 ): """ Retrieves lists accessible to the current user: @@ -128,7 +116,6 @@ async def read_lists_statuses( statuses = await crud_list.get_lists_statuses_by_ids(db=db, list_ids=ids, user_id=current_user.id) - # The CRUD function returns a list of Row objects, so we map them to the Pydantic model return [ ListStatusWithId( id=s.id, @@ -141,7 +128,7 @@ async def read_lists_statuses( @router.get( "/{list_id}", - response_model=ListDetail, # Return detailed list info including items + response_model=ListDetail, summary="Get List Details", tags=["Lists"] ) @@ -155,17 +142,16 @@ async def read_list( if the user has permission (creator or group member). 
""" logger.info(f"User {current_user.email} requesting details for list ID: {list_id}") - # The check_list_permission function will raise appropriate exceptions list_db = await crud_list.check_list_permission(db=db, list_id=list_id, user_id=current_user.id) return list_db @router.put( "/{list_id}", - response_model=ListPublic, # Return updated basic info + response_model=ListPublic, summary="Update List", tags=["Lists"], - responses={ # Add 409 to responses + responses={ status.HTTP_409_CONFLICT: {"description": "Conflict: List has been modified by someone else"} } ) @@ -188,22 +174,20 @@ async def update_list( updated_list = await crud_list.update_list(db=db, list_db=list_db, list_in=list_in) logger.info(f"List {list_id} updated successfully by user {current_user.email} to version {updated_list.version}.") return updated_list - except ConflictError as e: # Catch and re-raise as HTTPException for proper FastAPI response + except ConflictError as e: logger.warning(f"Conflict updating list {list_id} for user {current_user.email}: {str(e)}") raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail=str(e)) - except Exception as e: # Catch other potential errors from crud operation + except Exception as e: logger.error(f"Error updating list {list_id} for user {current_user.email}: {str(e)}") - # Consider a more generic error, but for now, let's keep it specific if possible - # Re-raising might be better if crud layer already raises appropriate HTTPExceptions raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="An unexpected error occurred while updating the list.") @router.delete( "/{list_id}", - status_code=status.HTTP_204_NO_CONTENT, # Standard for successful DELETE with no body + status_code=status.HTTP_204_NO_CONTENT, summary="Delete List", tags=["Lists"], - responses={ # Add 409 to responses + responses={ status.HTTP_409_CONFLICT: {"description": "Conflict: List has been modified, cannot delete specified version"} } ) @@ -219,7 +203,6 @@ async def delete_list( a 409 Conflict is returned. """ logger.info(f"User {current_user.email} attempting to delete list ID: {list_id}, expected version: {expected_version}") - # Use the helper, requiring creator permission list_db = await crud_list.check_list_permission(db=db, list_id=list_id, user_id=current_user.id, require_creator=True) if expected_version is not None and list_db.version != expected_version: @@ -253,7 +236,6 @@ async def read_list_status( if the user has permission (creator or group member). 
""" logger.info(f"User {current_user.email} requesting status for list ID: {list_id}") - # The check_list_permission is not needed here as get_list_status handles not found await crud_list.check_list_permission(db=db, list_id=list_id, user_id=current_user.id) return await crud_list.get_list_status(db=db, list_id=list_id) @@ -278,9 +260,7 @@ async def read_list_expenses( logger.info(f"User {current_user.email} requesting expenses for list ID: {list_id}") - # Check if user has permission to access this list await crud_list.check_list_permission(db=db, list_id=list_id, user_id=current_user.id) - # Get expenses for this list expenses = await crud_expense.get_expenses_for_list(db, list_id=list_id, skip=skip, limit=limit) return expenses \ No newline at end of file diff --git a/be/app/api/v1/endpoints/ocr.py b/be/app/api/v1/endpoints/ocr.py index 9a21689..b309a7a 100644 --- a/be/app/api/v1/endpoints/ocr.py +++ b/be/app/api/v1/endpoints/ocr.py @@ -1,17 +1,12 @@ import logging -from typing import List - -from fastapi import APIRouter, Depends, UploadFile, File, HTTPException, status -from google.api_core import exceptions as google_exceptions - +from fastapi import APIRouter, Depends, UploadFile, File from app.auth import current_active_user from app.models import User as UserModel from app.schemas.ocr import OcrExtractResponse from app.core.gemini import GeminiOCRService, gemini_initialization_error -from app.core.exceptions import ( +from app.core.exceptions import ( OCRServiceUnavailableError, OCRServiceConfigError, - OCRUnexpectedError, OCRQuotaExceededError, InvalidFileTypeError, FileTooLargeError, @@ -37,26 +32,22 @@ async def ocr_extract_items( Accepts an image upload, sends it to Gemini Flash with a prompt to extract shopping list items, and returns the parsed items. 
""" - # Check if Gemini client initialized correctly if gemini_initialization_error: logger.error("OCR endpoint called but Gemini client failed to initialize.") raise OCRServiceUnavailableError(gemini_initialization_error) logger.info(f"User {current_user.email} uploading image '{image_file.filename}' for OCR extraction.") - # --- File Validation --- if image_file.content_type not in settings.ALLOWED_IMAGE_TYPES: logger.warning(f"Invalid file type uploaded by {current_user.email}: {image_file.content_type}") raise InvalidFileTypeError() - # Simple size check contents = await image_file.read() if len(contents) > settings.MAX_FILE_SIZE_MB * 1024 * 1024: logger.warning(f"File too large uploaded by {current_user.email}: {len(contents)} bytes") raise FileTooLargeError() try: - # Use the ocr_service instance instead of the standalone function extracted_items = await ocr_service.extract_items(image_data=contents) logger.info(f"Successfully extracted {len(extracted_items)} items for user {current_user.email}.") @@ -72,5 +63,4 @@ async def ocr_extract_items( raise OCRProcessingError(str(e)) finally: - # Ensure file handle is closed await image_file.close() \ No newline at end of file diff --git a/be/app/auth.py b/be/app/auth.py index f8a7518..02b6522 100644 --- a/be/app/auth.py +++ b/be/app/auth.py @@ -21,11 +21,9 @@ from .database import get_session from .models import User from .config import settings -# OAuth2 configuration config = Config('.env') oauth = OAuth(config) -# Google OAuth2 setup oauth.register( name='google', server_metadata_url='https://accounts.google.com/.well-known/openid-configuration', @@ -35,7 +33,6 @@ oauth.register( } ) -# Apple OAuth2 setup oauth.register( name='apple', server_metadata_url='https://appleid.apple.com/.well-known/openid-configuration', @@ -45,13 +42,11 @@ oauth.register( } ) -# Custom Bearer Response with Refresh Token class BearerResponseWithRefresh(BaseModel): access_token: str refresh_token: str token_type: str = "bearer" -# Custom Bearer Transport that supports refresh tokens class BearerTransportWithRefresh(BearerTransport): async def get_login_response(self, token: str, refresh_token: str = None) -> Response: if refresh_token: @@ -61,14 +56,12 @@ class BearerTransportWithRefresh(BearerTransport): token_type="bearer" ) else: - # Fallback to standard response if no refresh token bearer_response = { "access_token": token, "token_type": "bearer" } return JSONResponse(bearer_response.dict() if hasattr(bearer_response, 'dict') else bearer_response) -# Custom Authentication Backend with Refresh Token Support class AuthenticationBackendWithRefresh(AuthenticationBackend): def __init__( self, @@ -83,7 +76,6 @@ class AuthenticationBackendWithRefresh(AuthenticationBackend): self.get_refresh_strategy = get_refresh_strategy async def login(self, strategy, user) -> Response: - # Generate both access and refresh tokens access_token = await strategy.write_token(user) refresh_strategy = self.get_refresh_strategy() refresh_token = await refresh_strategy.write_token(user) @@ -124,17 +116,14 @@ async def get_user_db(session: AsyncSession = Depends(get_session)): async def get_user_manager(user_db: SQLAlchemyUserDatabase = Depends(get_user_db)): yield UserManager(user_db) -# Updated transport with refresh token support bearer_transport = BearerTransportWithRefresh(tokenUrl="auth/jwt/login") def get_jwt_strategy() -> JWTStrategy: return JWTStrategy(secret=settings.SECRET_KEY, lifetime_seconds=settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60) -def get_refresh_jwt_strategy() -> 
JWTStrategy: - # Refresh tokens last longer - 7 days +def get_refresh_jwt_strategy() -> JWTStrategy: return JWTStrategy(secret=settings.SECRET_KEY, lifetime_seconds=7 * 24 * 60 * 60) -# Updated auth backend with refresh token support auth_backend = AuthenticationBackendWithRefresh( name="jwt", transport=bearer_transport, diff --git a/be/app/core/api_config.py b/be/app/core/api_config.py index 79af045..d8ee10b 100644 --- a/be/app/core/api_config.py +++ b/be/app/core/api_config.py @@ -1,28 +1,20 @@ -from typing import Dict, Any from app.config import settings -# API Version API_VERSION = "v1" - -# API Prefix API_PREFIX = f"/api/{API_VERSION}" -# API Endpoints class APIEndpoints: - # Auth AUTH = { "LOGIN": "/auth/login", "SIGNUP": "/auth/signup", "REFRESH_TOKEN": "/auth/refresh-token", } - # Users USERS = { "PROFILE": "/users/profile", "UPDATE_PROFILE": "/users/profile", } - # Lists LISTS = { "BASE": "/lists", "BY_ID": "/lists/{id}", @@ -30,7 +22,6 @@ class APIEndpoints: "ITEM": "/lists/{list_id}/items/{item_id}", } - # Groups GROUPS = { "BASE": "/groups", "BY_ID": "/groups/{id}", @@ -38,7 +29,6 @@ class APIEndpoints: "MEMBERS": "/groups/{group_id}/members", } - # Invites INVITES = { "BASE": "/invites", "BY_ID": "/invites/{id}", @@ -46,12 +36,10 @@ class APIEndpoints: "DECLINE": "/invites/{id}/decline", } - # OCR OCR = { "PROCESS": "/ocr/process", } - # Financials FINANCIALS = { "EXPENSES": "/financials/expenses", "EXPENSE": "/financials/expenses/{id}", @@ -59,12 +47,10 @@ class APIEndpoints: "SETTLEMENT": "/financials/settlements/{id}", } - # Health HEALTH = { "CHECK": "/health", } -# API Metadata API_METADATA = { "title": settings.API_TITLE, "description": settings.API_DESCRIPTION, @@ -74,7 +60,6 @@ API_METADATA = { "redoc_url": settings.API_REDOC_URL, } -# API Tags API_TAGS = [ {"name": "Authentication", "description": "Authentication and authorization endpoints"}, {"name": "Users", "description": "User management endpoints"}, @@ -86,7 +71,7 @@ API_TAGS = [ {"name": "Health", "description": "Health check endpoints"}, ] -# Helper function to get full API URL + def get_api_url(endpoint: str, **kwargs) -> str: """ Get the full API URL for an endpoint. diff --git a/be/app/core/chore_utils.py b/be/app/core/chore_utils.py index ed7516f..6317a78 100644 --- a/be/app/core/chore_utils.py +++ b/be/app/core/chore_utils.py @@ -48,7 +48,6 @@ def calculate_next_due_date( today = date.today() reference_future_date = max(today, base_date) - # This loop ensures the next_due date is always in the future relative to the reference_future_date. while next_due <= reference_future_date: current_base_for_recalc = next_due @@ -70,9 +69,7 @@ def calculate_next_due_date( else: # Should not be reached break - # Safety break: if date hasn't changed, interval is zero or logic error. if next_due == current_base_for_recalc: - # Log error ideally, then advance by one day to prevent infinite loop. 
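The catch-up behaviour of this loop is easier to see in isolation. Below is a self-contained sketch of the same pattern for a plain day-interval recurrence; the helper name and example dates are illustrative only, and the real `calculate_next_due_date` presumably also covers the other recurrence frequencies elided from this hunk.

```python
# Minimal illustration of the catch-up loop pattern used by calculate_next_due_date.
from datetime import date, timedelta


def next_due_after(base_date: date, interval_days: int, today: date) -> date:
    """Advance base_date by whole intervals until it lands strictly in the future."""
    reference_future_date = max(today, base_date)
    next_due = base_date
    while next_due <= reference_future_date:
        current_base_for_recalc = next_due
        next_due = current_base_for_recalc + timedelta(days=interval_days)
        if next_due == current_base_for_recalc:  # zero interval: avoid an infinite loop
            next_due += timedelta(days=1)
            break
    return next_due


# A chore created three weeks ago on a 7-day cycle comes due a week from today,
# not the day after it was created:
assert next_due_after(date(2024, 1, 1), 7, date(2024, 1, 22)) == date(2024, 1, 29)
```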
next_due += timedelta(days=1) break diff --git a/be/app/core/exceptions.py b/be/app/core/exceptions.py index c9a2d1f..0b5f936 100644 --- a/be/app/core/exceptions.py +++ b/be/app/core/exceptions.py @@ -362,4 +362,3 @@ class PermissionDeniedError(HTTPException): detail=detail ) -# Financials & Cost Splitting specific errors \ No newline at end of file diff --git a/be/app/core/gemini.py b/be/app/core/gemini.py index a8c308c..6636874 100644 --- a/be/app/core/gemini.py +++ b/be/app/core/gemini.py @@ -1,8 +1,6 @@ -# app/core/gemini.py import logging from typing import List import google.generativeai as genai -from google.generativeai.types import HarmCategory, HarmBlockThreshold # For safety settings from google.api_core import exceptions as google_exceptions from app.config import settings from app.core.exceptions import ( @@ -15,15 +13,12 @@ from app.core.exceptions import ( logger = logging.getLogger(__name__) -# --- Global variable to hold the initialized model client --- gemini_flash_client = None -gemini_initialization_error = None # Store potential init error +gemini_initialization_error = None -# --- Configure and Initialize --- try: if settings.GEMINI_API_KEY: genai.configure(api_key=settings.GEMINI_API_KEY) - # Initialize the specific model we want to use gemini_flash_client = genai.GenerativeModel( model_name=settings.GEMINI_MODEL_NAME, generation_config=genai.types.GenerationConfig( @@ -32,18 +27,15 @@ try: ) logger.info(f"Gemini AI client initialized successfully for model '{settings.GEMINI_MODEL_NAME}'.") else: - # Store error if API key is missing gemini_initialization_error = "GEMINI_API_KEY not configured. Gemini client not initialized." logger.error(gemini_initialization_error) except Exception as e: - # Catch any other unexpected errors during initialization gemini_initialization_error = f"Failed to initialize Gemini AI client: {e}" - logger.exception(gemini_initialization_error) # Log full traceback - gemini_flash_client = None # Ensure client is None on error + logger.exception(gemini_initialization_error) + gemini_flash_client = None -# --- Function to get the client (optional, allows checking error) --- def get_gemini_client(): """ Returns the initialized Gemini client instance. @@ -52,23 +44,172 @@ def get_gemini_client(): if gemini_initialization_error: raise OCRServiceConfigError() if gemini_flash_client is None: - # This case should ideally be covered by the check above, but as a safeguard: raise OCRServiceConfigError() return gemini_flash_client -# Define the prompt as a constant OCR_ITEM_EXTRACTION_PROMPT = """ -Extract the shopping list items from this image. -List each distinct item on a new line. -Ignore prices, quantities, store names, discounts, taxes, totals, and other non-item text. -Focus only on the names of the products or items to be purchased. -If the image does not appear to be a shopping list or receipt, state that clearly. -Example output for a grocery list: -Milk -Eggs -Bread -Apples -Organic Bananas +**ROLE & GOAL** + +You are an expert AI assistant specializing in Optical Character Recognition (OCR) and structured data extraction. Your primary function is to act as a "Shopping List Digitizer." + +Your goal is to meticulously analyze the provided image of a shopping list, which is likely handwritten, and convert it into a structured, machine-readable JSON format. You must be accurate, infer context where necessary, and handle the inherent ambiguities of handwriting and informal list-making. + +**INPUT** + +You will receive a single image (`[Image]`). 
This image contains a shopping list. It may be: +* Neatly written or very messy. +* On lined paper, a whiteboard, a napkin, or a dedicated notepad. +* Containing doodles, stains, or other visual noise. +* Using various formats (bullet points, numbered lists, columns, simple line breaks). +* could be in English or in German. + +**CORE TASK: STEP-BY-STEP ANALYSIS** + +Follow these steps precisely: + +1. **Initial Image Analysis & OCR:** + * Perform an advanced OCR scan on the entire image to transcribe all visible text. + * Pay close attention to the spatial layout. Identify headings, columns, and line items. Note which text elements appear to be grouped together. + +2. **Item Identification & Filtering:** + * Differentiate between actual list items and non-item elements. + * **INCLUDE:** Items intended for purchase. + * **EXCLUDE:** List titles (e.g., "GROCERIES," "Target List"), dates, doodles, unrelated notes, or stray marks. Capture the list title separately if one exists. + +3. **Detailed Extraction for Each Item:** + For every single item you identify, extract the following attributes. If an attribute is not present, use `null`. + + * `item_name` (string): The primary name of the product. + * **Standardize:** Normalize the name. (e.g., "B. Powder" -> "Baking Powder", "A. Juice" -> "Apple Juice"). + * **Contextual Guessing:** If a word is poorly written, use the context of a shopping list to make an educated guess. (e.g., "Ciffee" is almost certainly "Coffee"). + + * `quantity` (number or string): The amount needed. + * If a number is present (e.g., "**2** milks"), extract the number `2`. + * If it's a word (e.g., "**a dozen** eggs"), extract the string `"a dozen"`. + * If no quantity is specified (e.g., "Bread"), infer a default quantity of `1`. + + * `unit` (string): The unit of measurement or packaging. + * Examples: "kg", "lbs", "liters", "gallons", "box", "can", "bag", "bunch". + * Infer where possible (e.g., for "2 Milks," the unit could be inferred as "cartons" or "gallons" depending on regional context, but it's safer to leave it `null` if not explicitly stated). + + * `notes` (string): Any additional descriptive text. + * Examples: "low-sodium," "organic," "brand name (Tide)," "for the cake," "get the ripe ones." + + * `category` (string): Infer a logical category for the item. + * Use common grocery store categories: `Produce`, `Dairy & Eggs`, `Meat & Seafood`, `Pantry`, `Frozen`, `Bakery`, `Beverages`, `Household`, `Personal Care`. + * If the list itself has category headings (e.g., a "DAIRY" section), use those first. + + * `original_text` (string): Provide the exact, unaltered text that your OCR transcribed for this entire line item. This is crucial for verification. + + * `is_crossed_out` (boolean): Set to `true` if the item is struck through, crossed out, or clearly marked as completed. Otherwise, set to `false`. + +**HANDLING AMBIGUITIES AND EDGE CASES** + +* **Illegible Text:** If a line or word is completely unreadable, set `item_name` to `"UNREADABLE"` and place the garbled OCR attempt in the `original_text` field. +* **Abbreviations:** Expand common shopping list abbreviations (e.g., "OJ" -> "Orange Juice", "TP" -> "Toilet Paper", "AVOs" -> "Avocados", "G. Beef" -> "Ground Beef"). +* **Implicit Items:** If a line is vague like "Snacks for kids," list it as is. Do not invent specific items. +* **Multi-item Lines:** If a line contains multiple items (e.g., "Onions, Garlic, Ginger"), split them into separate item objects. 
+ +**OUTPUT FORMAT** + +Your final output MUST be a single JSON object with the following structure. Do not include any explanatory text before or after the JSON block. + +```json +{ + "list_title": "string or null", + "items": [ + { + "item_name": "string", + "quantity": "number or string", + "unit": "string or null", + "category": "string", + "notes": "string or null", + "original_text": "string", + "is_crossed_out": "boolean" + } + ], + "summary": { + "total_items": "integer", + "unread_items": "integer", + "crossed_out_items": "integer" + } +} +``` + +**EXAMPLE WALKTHROUGH** + +* **IF THE IMAGE SHOWS:** A crumpled sticky note with the title "Stuff for tonight" and the items: + * `2x Chicken Breasts` + * `~~Baguette~~` (this item is crossed out) + * `Salad mix (bag)` + * `Tomatos` (misspelled) + * `Choc Ice Cream` + +* **YOUR JSON OUTPUT SHOULD BE:** + +```json +{ + "list_title": "Stuff for tonight", + "items": [ + { + "item_name": "Chicken Breasts", + "quantity": 2, + "unit": null, + "category": "Meat & Seafood", + "notes": null, + "original_text": "2x Chicken Breasts", + "is_crossed_out": false + }, + { + "item_name": "Baguette", + "quantity": 1, + "unit": null, + "category": "Bakery", + "notes": null, + "original_text": "Baguette", + "is_crossed_out": true + }, + { + "item_name": "Salad Mix", + "quantity": 1, + "unit": "bag", + "category": "Produce", + "notes": null, + "original_text": "Salad mix (bag)", + "is_crossed_out": false + }, + { + "item_name": "Tomatoes", + "quantity": 1, + "unit": null, + "category": "Produce", + "notes": null, + "original_text": "Tomatos", + "is_crossed_out": false + }, + { + "item_name": "Chocolate Ice Cream", + "quantity": 1, + "unit": null, + "category": "Frozen", + "notes": null, + "original_text": "Choc Ice Cream", + "is_crossed_out": false + } + ], + "summary": { + "total_items": 5, + "unread_items": 0, + "crossed_out_items": 1 + } +} +``` + +**FINAL INSTRUCTION** + +If the image provided is not a shopping list or is completely blank/unintelligible, respond with a JSON object where the `items` array is empty and add a note in the `list_title` field, such as "Image does not appear to be a shopping list." + +Now, analyze the provided image and generate the JSON output. 
""" async def extract_items_from_image_gemini(image_bytes: bytes, mime_type: str = "image/jpeg") -> List[str]: @@ -92,29 +233,22 @@ async def extract_items_from_image_gemini(image_bytes: bytes, mime_type: str = " try: client = get_gemini_client() # Raises OCRServiceConfigError if not initialized - # Prepare image part for multimodal input image_part = { "mime_type": mime_type, "data": image_bytes } - # Prepare the full prompt content prompt_parts = [ - settings.OCR_ITEM_EXTRACTION_PROMPT, # Text prompt first - image_part # Then the image + settings.OCR_ITEM_EXTRACTION_PROMPT, + image_part ] logger.info("Sending image to Gemini for item extraction...") - - # Make the API call - # Use generate_content_async for async FastAPI + response = await client.generate_content_async(prompt_parts) - # --- Process the response --- - # Check for safety blocks or lack of content if not response.candidates or not response.candidates[0].content.parts: logger.warning("Gemini response blocked or empty.", extra={"response": response}) - # Check finish_reason if available finish_reason = response.candidates[0].finish_reason if response.candidates else 'UNKNOWN' safety_ratings = response.candidates[0].safety_ratings if response.candidates else 'N/A' if finish_reason == 'SAFETY': @@ -122,18 +256,13 @@ async def extract_items_from_image_gemini(image_bytes: bytes, mime_type: str = " else: raise OCRUnexpectedError() - # Extract text - assumes the first part of the first candidate is the text response - raw_text = response.text # response.text is a shortcut for response.candidates[0].content.parts[0].text + raw_text = response.text logger.info("Received raw text from Gemini.") - # logger.debug(f"Gemini Raw Text:\n{raw_text}") # Optional: Log full response text - # Parse the text response items = [] - for line in raw_text.splitlines(): # Split by newline - cleaned_line = line.strip() # Remove leading/trailing whitespace - # Basic filtering: ignore empty lines and potential non-item lines - if cleaned_line and len(cleaned_line) > 1: # Ignore very short lines too? - # Add more sophisticated filtering if needed (e.g., regex, keyword check) + for line in raw_text.splitlines(): + cleaned_line = line.strip() + if cleaned_line and len(cleaned_line) > 1: items.append(cleaned_line) logger.info(f"Extracted {len(items)} potential items.") @@ -145,12 +274,9 @@ async def extract_items_from_image_gemini(image_bytes: bytes, mime_type: str = " raise OCRQuotaExceededError() raise OCRServiceUnavailableError() except (OCRServiceConfigError, OCRQuotaExceededError, OCRServiceUnavailableError, OCRProcessingError, OCRUnexpectedError): - # Re-raise specific OCR exceptions raise except Exception as e: - # Catch other unexpected errors during generation or processing logger.error(f"Unexpected error during Gemini item extraction: {e}", exc_info=True) - # Wrap in a custom exception raise OCRUnexpectedError() class GeminiOCRService: @@ -186,27 +312,22 @@ class GeminiOCRService: OCRUnexpectedError: For any other unexpected errors. 
""" try: - # Create image part image_parts = [{"mime_type": mime_type, "data": image_data}] - # Generate content response = await self.model.generate_content_async( contents=[settings.OCR_ITEM_EXTRACTION_PROMPT, *image_parts] ) - - # Process response + if not response.text: logger.warning("Gemini response is empty") raise OCRUnexpectedError() - # Check for safety blocks if hasattr(response, 'candidates') and response.candidates and hasattr(response.candidates[0], 'finish_reason'): finish_reason = response.candidates[0].finish_reason if finish_reason == 'SAFETY': safety_ratings = response.candidates[0].safety_ratings if hasattr(response.candidates[0], 'safety_ratings') else 'N/A' raise OCRProcessingError(f"Gemini response blocked due to safety settings. Ratings: {safety_ratings}") - # Split response into lines and clean up items = [] for line in response.text.splitlines(): cleaned_line = line.strip() @@ -222,7 +343,6 @@ class GeminiOCRService: raise OCRQuotaExceededError() raise OCRServiceUnavailableError() except (OCRServiceConfigError, OCRQuotaExceededError, OCRServiceUnavailableError, OCRProcessingError, OCRUnexpectedError): - # Re-raise specific OCR exceptions raise except Exception as e: logger.error(f"Unexpected error during Gemini item extraction: {e}", exc_info=True) diff --git a/be/app/core/scheduler.py b/be/app/core/scheduler.py index 9227cf2..0844979 100644 --- a/be/app/core/scheduler.py +++ b/be/app/core/scheduler.py @@ -2,7 +2,6 @@ from apscheduler.schedulers.asyncio import AsyncIOScheduler from apscheduler.jobstores.sqlalchemy import SQLAlchemyJobStore from apscheduler.executors.pool import ThreadPoolExecutor from apscheduler.triggers.cron import CronTrigger -from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession from app.config import settings from app.jobs.recurring_expenses import generate_recurring_expenses from app.db.session import async_session @@ -10,11 +9,8 @@ import logging logger = logging.getLogger(__name__) -# Convert async database URL to sync URL for APScheduler -# Replace postgresql+asyncpg:// with postgresql:// sync_db_url = settings.DATABASE_URL.replace('postgresql+asyncpg://', 'postgresql://') -# Configure the scheduler jobstores = { 'default': SQLAlchemyJobStore(url=sync_db_url) } @@ -36,7 +32,10 @@ scheduler = AsyncIOScheduler( ) async def run_recurring_expenses_job(): - """Wrapper function to run the recurring expenses job with a database session.""" + """Wrapper function to run the recurring expenses job with a database session. + + This function is used to generate recurring expenses for the user. 
+ """ try: async with async_session() as session: await generate_recurring_expenses(session) @@ -47,7 +46,6 @@ async def run_recurring_expenses_job(): def init_scheduler(): """Initialize and start the scheduler.""" try: - # Add the recurring expenses job scheduler.add_job( run_recurring_expenses_job, trigger=CronTrigger(hour=0, minute=0), # Run at midnight UTC @@ -56,7 +54,6 @@ def init_scheduler(): replace_existing=True ) - # Start the scheduler scheduler.start() logger.info("Scheduler started successfully") except Exception as e: diff --git a/be/app/core/security.py b/be/app/core/security.py index 197c732..6914170 100644 --- a/be/app/core/security.py +++ b/be/app/core/security.py @@ -1,20 +1,5 @@ -# app/core/security.py -from datetime import datetime, timedelta, timezone -from typing import Any, Union, Optional - -from jose import JWTError, jwt from passlib.context import CryptContext -from app.config import settings # Import settings from config - -# --- Password Hashing --- -# These functions are used for password hashing and verification -# They complement FastAPI-Users but provide direct access to the underlying password functionality -# when needed outside of the FastAPI-Users authentication flow. - -# Configure passlib context -# Using bcrypt as the default hashing scheme -# 'deprecated="auto"' will automatically upgrade hashes if needed on verification pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") def verify_password(plain_password: str, hashed_password: str) -> bool: @@ -33,7 +18,6 @@ def verify_password(plain_password: str, hashed_password: str) -> bool: try: return pwd_context.verify(plain_password, hashed_password) except Exception: - # Handle potential errors during verification (e.g., invalid hash format) return False def hash_password(password: str) -> str: @@ -48,26 +32,4 @@ def hash_password(password: str) -> str: Returns: The resulting hash string. """ - return pwd_context.hash(password) - - -# --- JSON Web Tokens (JWT) --- -# FastAPI-Users now handles all JWT token creation and validation. -# The code below is commented out because FastAPI-Users provides these features. -# It's kept for reference in case a custom implementation is needed later. - -# Example of a potential future implementation: -# def get_subject_from_token(token: str) -> Optional[str]: -# """ -# Extract the subject (user ID) from a JWT token. -# This would be used if we need to validate tokens outside of FastAPI-Users flow. -# For now, use fastapi_users.current_user dependency instead. 
-# """ -# # This would need to use FastAPI-Users' token verification if ever implemented -# # For example, by decoding the token using the strategy from the auth backend -# try: -# payload = jwt.decode(token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM]) -# return payload.get("sub") -# except JWTError: -# return None -# return None \ No newline at end of file + return pwd_context.hash(password) \ No newline at end of file diff --git a/be/app/crud/chore.py b/be/app/crud/chore.py index f8ca622..f63493c 100644 --- a/be/app/crud/chore.py +++ b/be/app/crud/chore.py @@ -18,16 +18,14 @@ logger = logging.getLogger(__name__) async def get_all_user_chores(db: AsyncSession, user_id: int) -> List[Chore]: """Gets all chores (personal and group) for a user in optimized queries.""" - # Get personal chores query personal_chores_query = ( - select(Chore) + select(Chore) .where( Chore.created_by_id == user_id, Chore.type == ChoreTypeEnum.personal ) ) - # Get user's group IDs first user_groups_result = await db.execute( select(UserGroup.group_id).where(UserGroup.user_id == user_id) ) @@ -35,7 +33,6 @@ async def get_all_user_chores(db: AsyncSession, user_id: int) -> List[Chore]: all_chores = [] - # Execute personal chores query personal_result = await db.execute( personal_chores_query .options( @@ -48,7 +45,6 @@ async def get_all_user_chores(db: AsyncSession, user_id: int) -> List[Chore]: ) all_chores.extend(personal_result.scalars().all()) - # If user has groups, get all group chores in one query if user_group_ids: group_chores_result = await db.execute( select(Chore) @@ -76,12 +72,10 @@ async def create_chore( group_id: Optional[int] = None ) -> Chore: """Creates a new chore, either personal or within a specific group.""" - # Use the transaction pattern from the FastAPI strategy async with db.begin_nested() if db.in_transaction() else db.begin(): if chore_in.type == ChoreTypeEnum.group: if not group_id: raise ValueError("group_id is required for group chores") - # Validate group existence and user membership group = await get_group_by_id(db, group_id) if not group: raise GroupNotFoundError(group_id) @@ -97,14 +91,12 @@ async def create_chore( created_by_id=user_id, ) - # Specific check for custom frequency if chore_in.frequency == ChoreFrequencyEnum.custom and chore_in.custom_interval_days is None: raise ValueError("custom_interval_days must be set for custom frequency chores.") db.add(db_chore) - await db.flush() # Get the ID for the chore + await db.flush() - # Log history await create_chore_history_entry( db, chore_id=db_chore.id, @@ -115,7 +107,6 @@ async def create_chore( ) try: - # Load relationships for the response with eager loading result = await db.execute( select(Chore) .where(Chore.id == db_chore.id) @@ -221,10 +212,8 @@ async def update_chore( if not db_chore: raise ChoreNotFoundError(chore_id, group_id) - # Store original state for history original_data = {field: getattr(db_chore, field) for field in chore_in.model_dump(exclude_unset=True)} - # Check permissions if db_chore.type == ChoreTypeEnum.group: if not group_id: raise ValueError("group_id is required for group chores") @@ -232,7 +221,7 @@ async def update_chore( raise PermissionDeniedError(detail=f"User {user_id} not a member of group {group_id}") if db_chore.group_id != group_id: raise ChoreNotFoundError(chore_id, group_id) - else: # personal chore + else: if group_id: raise ValueError("group_id must be None for personal chores") if db_chore.created_by_id != user_id: @@ -240,7 +229,6 @@ async def update_chore( update_data = 
chore_in.model_dump(exclude_unset=True) - # Handle type change if 'type' in update_data: new_type = update_data['type'] if new_type == ChoreTypeEnum.group and not group_id: @@ -275,7 +263,6 @@ async def update_chore( if db_chore.frequency == ChoreFrequencyEnum.custom and db_chore.custom_interval_days is None: raise ValueError("custom_interval_days must be set for custom frequency chores.") - # Log history for changes changes = {} for field, old_value in original_data.items(): new_value = getattr(db_chore, field) @@ -293,7 +280,7 @@ async def update_chore( ) try: - await db.flush() # Flush changes within the transaction + await db.flush() result = await db.execute( select(Chore) .where(Chore.id == db_chore.id) @@ -322,7 +309,6 @@ async def delete_chore( if not db_chore: raise ChoreNotFoundError(chore_id, group_id) - # Log history before deleting await create_chore_history_entry( db, chore_id=chore_id, @@ -332,7 +318,6 @@ async def delete_chore( event_data={"chore_name": db_chore.name} ) - # Check permissions if db_chore.type == ChoreTypeEnum.group: if not group_id: raise ValueError("group_id is required for group chores") @@ -348,7 +333,7 @@ async def delete_chore( try: await db.delete(db_chore) - await db.flush() # Ensure deletion is processed within the transaction + await db.flush() return True except Exception as e: logger.error(f"Error deleting chore {chore_id}: {e}", exc_info=True) @@ -363,27 +348,23 @@ async def create_chore_assignment( ) -> ChoreAssignment: """Creates a new chore assignment. User must be able to manage the chore.""" async with db.begin_nested() if db.in_transaction() else db.begin(): - # Get the chore and validate permissions chore = await get_chore_by_id(db, assignment_in.chore_id) if not chore: raise ChoreNotFoundError(chore_id=assignment_in.chore_id) - # Check permissions to assign this chore if chore.type == ChoreTypeEnum.personal: if chore.created_by_id != user_id: raise PermissionDeniedError(detail="Only the creator can assign personal chores") else: # group chore if not await is_user_member(db, chore.group_id, user_id): raise PermissionDeniedError(detail=f"User {user_id} not a member of group {chore.group_id}") - # For group chores, check if assignee is also a group member if not await is_user_member(db, chore.group_id, assignment_in.assigned_to_user_id): raise PermissionDeniedError(detail=f"Cannot assign chore to user {assignment_in.assigned_to_user_id} who is not a group member") db_assignment = ChoreAssignment(**assignment_in.model_dump(exclude_unset=True)) db.add(db_assignment) - await db.flush() # Get the ID for the assignment + await db.flush() - # Log history await create_assignment_history_entry( db, assignment_id=db_assignment.id, @@ -393,7 +374,6 @@ async def create_chore_assignment( ) try: - # Load relationships for the response result = await db.execute( select(ChoreAssignment) .where(ChoreAssignment.id == db_assignment.id) @@ -450,12 +430,11 @@ async def get_chore_assignments( chore = await get_chore_by_id(db, chore_id) if not chore: raise ChoreNotFoundError(chore_id=chore_id) - - # Check permissions + if chore.type == ChoreTypeEnum.personal: if chore.created_by_id != user_id: raise PermissionDeniedError(detail="Can only view assignments for own personal chores") - else: # group chore + else: if not await is_user_member(db, chore.group_id, user_id): raise PermissionDeniedError(detail=f"User {user_id} not a member of group {chore.group_id}") @@ -487,11 +466,10 @@ async def update_chore_assignment( if not chore: raise 
ChoreNotFoundError(chore_id=db_assignment.chore_id) - # Check permissions - only assignee can complete, but chore managers can reschedule can_manage = False if chore.type == ChoreTypeEnum.personal: can_manage = chore.created_by_id == user_id - else: # group chore + else: can_manage = await is_user_member(db, chore.group_id, user_id) can_complete = db_assignment.assigned_to_user_id == user_id @@ -501,7 +479,6 @@ async def update_chore_assignment( original_assignee = db_assignment.assigned_to_user_id original_due_date = db_assignment.due_date - # Check specific permissions for different updates if 'is_complete' in update_data and not can_complete: raise PermissionDeniedError(detail="Only the assignee can mark assignments as complete") @@ -515,7 +492,6 @@ async def update_chore_assignment( raise PermissionDeniedError(detail="Only chore managers can reassign assignments") await create_assignment_history_entry(db, assignment_id=assignment_id, changed_by_user_id=user_id, event_type=ChoreHistoryEventTypeEnum.REASSIGNED, event_data={"old": original_assignee, "new": update_data['assigned_to_user_id']}) - # Handle completion logic if 'is_complete' in update_data: if update_data['is_complete'] and not db_assignment.is_complete: update_data['completed_at'] = datetime.utcnow() @@ -531,13 +507,11 @@ async def update_chore_assignment( update_data['completed_at'] = None await create_assignment_history_entry(db, assignment_id=assignment_id, changed_by_user_id=user_id, event_type=ChoreHistoryEventTypeEnum.REOPENED) - # Apply updates for field, value in update_data.items(): setattr(db_assignment, field, value) try: await db.flush() - # Load relationships for the response result = await db.execute( select(ChoreAssignment) .where(ChoreAssignment.id == db_assignment.id) @@ -563,7 +537,6 @@ async def delete_chore_assignment( if not db_assignment: raise ChoreNotFoundError(assignment_id=assignment_id) - # Log history before deleting await create_assignment_history_entry( db, assignment_id=assignment_id, @@ -572,22 +545,20 @@ async def delete_chore_assignment( event_data={"unassigned_user_id": db_assignment.assigned_to_user_id} ) - # Load the chore for permission checking chore = await get_chore_by_id(db, db_assignment.chore_id) if not chore: raise ChoreNotFoundError(chore_id=db_assignment.chore_id) - # Check permissions if chore.type == ChoreTypeEnum.personal: if chore.created_by_id != user_id: raise PermissionDeniedError(detail="Only the creator can delete personal chore assignments") - else: # group chore + else: if not await is_user_member(db, chore.group_id, user_id): raise PermissionDeniedError(detail=f"User {user_id} not a member of group {chore.group_id}") try: await db.delete(db_assignment) - await db.flush() # Ensure deletion is processed within the transaction + await db.flush() return True except Exception as e: logger.error(f"Error deleting chore assignment {assignment_id}: {e}", exc_info=True) diff --git a/be/app/crud/group.py b/be/app/crud/group.py index 054fa08..976eed8 100644 --- a/be/app/crud/group.py +++ b/be/app/crud/group.py @@ -1,15 +1,14 @@ -# app/crud/group.py from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.future import select -from sqlalchemy.orm import selectinload # For eager loading members +from sqlalchemy.orm import selectinload from sqlalchemy.exc import SQLAlchemyError, IntegrityError, OperationalError from typing import Optional, List from sqlalchemy import delete, func -import logging # Add logging import +import logging from app.models import User as UserModel, 
Group as GroupModel, UserGroup as UserGroupModel from app.schemas.group import GroupCreate -from app.models import UserRoleEnum # Import enum +from app.models import UserRoleEnum from app.core.exceptions import ( GroupOperationError, GroupNotFoundError, diff --git a/be/app/crud/history.py b/be/app/crud/history.py index e5a8012..eabd421 100644 --- a/be/app/crud/history.py +++ b/be/app/crud/history.py @@ -1,4 +1,3 @@ -# be/app/crud/history.py from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.future import select from sqlalchemy.orm import selectinload @@ -76,7 +75,7 @@ async def get_group_chore_history(db: AsyncSession, group_id: int) -> List[Chore .where(ChoreHistory.group_id == group_id) .options( selectinload(ChoreHistory.changed_by_user), - selectinload(ChoreHistory.chore) # Also load chore info if available + selectinload(ChoreHistory.chore) ) .order_by(ChoreHistory.timestamp.desc()) ) diff --git a/be/app/crud/invite.py b/be/app/crud/invite.py index c42b359..b933cc4 100644 --- a/be/app/crud/invite.py +++ b/be/app/crud/invite.py @@ -1,26 +1,24 @@ -# app/crud/invite.py -import logging # Add logging import +import logging import secrets from datetime import datetime, timedelta, timezone from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.future import select -from sqlalchemy.orm import selectinload # Ensure selectinload is imported -from sqlalchemy import delete # Import delete statement +from sqlalchemy.orm import selectinload +from sqlalchemy import delete from sqlalchemy.exc import SQLAlchemyError, OperationalError, IntegrityError from typing import Optional -from app.models import Invite as InviteModel, Group as GroupModel, User as UserModel # Import related models for selectinload +from app.models import Invite as InviteModel, Group as GroupModel, User as UserModel from app.core.exceptions import ( DatabaseConnectionError, DatabaseIntegrityError, DatabaseQueryError, DatabaseTransactionError, - InviteOperationError # Add new specific exception + InviteOperationError ) -logger = logging.getLogger(__name__) # Initialize logger +logger = logging.getLogger(__name__) -# Invite codes should be reasonably unique, but handle potential collision MAX_CODE_GENERATION_ATTEMPTS = 5 async def deactivate_all_active_invites_for_group(db: AsyncSession, group_id: int): @@ -35,15 +33,13 @@ async def deactivate_all_active_invites_for_group(db: AsyncSession, group_id: in active_invites = result.scalars().all() if not active_invites: - return # No active invites to deactivate + return for invite in active_invites: invite.is_active = False db.add(invite) - await db.flush() # Flush changes within this transaction block + await db.flush() - # await db.flush() # Removed: Rely on caller to flush/commit - # No explicit commit here, assuming it's part of a larger transaction or caller handles commit. 
except OperationalError as e: logger.error(f"Database connection error deactivating invites for group {group_id}: {str(e)}", exc_info=True) raise DatabaseConnectionError(f"DB connection error deactivating invites for group {group_id}: {str(e)}") @@ -51,12 +47,11 @@ async def deactivate_all_active_invites_for_group(db: AsyncSession, group_id: in logger.error(f"Unexpected SQLAlchemy error deactivating invites for group {group_id}: {str(e)}", exc_info=True) raise DatabaseTransactionError(f"DB transaction error deactivating invites for group {group_id}: {str(e)}") -async def create_invite(db: AsyncSession, group_id: int, creator_id: int, expires_in_days: int = 365 * 100) -> Optional[InviteModel]: # Default to 100 years +async def create_invite(db: AsyncSession, group_id: int, creator_id: int, expires_in_days: int = 365 * 100) -> Optional[InviteModel]: """Creates a new invite code for a group, deactivating any existing active ones for that group first.""" try: async with db.begin_nested() if db.in_transaction() else db.begin(): - # Deactivate existing active invites for this group await deactivate_all_active_invites_for_group(db, group_id) expires_at = datetime.now(timezone.utc) + timedelta(days=expires_in_days) @@ -101,7 +96,7 @@ async def create_invite(db: AsyncSession, group_id: int, creator_id: int, expire raise InviteOperationError("Failed to load invite after creation and flush.") return loaded_invite - except InviteOperationError: # Already specific, re-raise + except InviteOperationError: raise except IntegrityError as e: logger.error(f"Database integrity error during invite creation for group {group_id}: {str(e)}", exc_info=True) @@ -121,13 +116,12 @@ async def get_active_invite_for_group(db: AsyncSession, group_id: int) -> Option select(InviteModel).where( InviteModel.group_id == group_id, InviteModel.is_active == True, - InviteModel.expires_at > now # Still respect expiry, even if very long + InviteModel.expires_at > now ) - .order_by(InviteModel.created_at.desc()) # Get the most recent one if multiple (should not happen) .limit(1) .options( - selectinload(InviteModel.group), # Eager load group - selectinload(InviteModel.creator) # Eager load creator + selectinload(InviteModel.group), + selectinload(InviteModel.creator) ) ) result = await db.execute(stmt) @@ -166,10 +160,9 @@ async def deactivate_invite(db: AsyncSession, invite: InviteModel) -> InviteMode try: async with db.begin_nested() if db.in_transaction() else db.begin() as transaction: invite.is_active = False - db.add(invite) # Add to session to track change - await db.flush() # Persist is_active change + db.add(invite) + await db.flush() - # Re-fetch with relationships stmt = ( select(InviteModel) .where(InviteModel.id == invite.id) @@ -181,7 +174,7 @@ async def deactivate_invite(db: AsyncSession, invite: InviteModel) -> InviteMode result = await db.execute(stmt) updated_invite = result.scalar_one_or_none() - if updated_invite is None: # Should not happen as invite is passed in + if updated_invite is None: raise InviteOperationError("Failed to load invite after deactivation.") return updated_invite @@ -192,8 +185,3 @@ async def deactivate_invite(db: AsyncSession, invite: InviteModel) -> InviteMode logger.error(f"Unexpected SQLAlchemy error deactivating invite: {str(e)}", exc_info=True) raise DatabaseTransactionError(f"DB transaction error deactivating invite: {str(e)}") -# Ensure InviteOperationError is defined in app.core.exceptions -# Example: class InviteOperationError(AppException): pass - -# Optional: Function to 
periodically delete old, inactive invites -# async def cleanup_old_invites(db: AsyncSession, older_than_days: int = 30): ... \ No newline at end of file diff --git a/be/app/crud/item.py b/be/app/crud/item.py index c1d9183..6045043 100644 --- a/be/app/crud/item.py +++ b/be/app/crud/item.py @@ -1,15 +1,14 @@ -# app/crud/item.py from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.future import select -from sqlalchemy.orm import selectinload # Ensure selectinload is imported -from sqlalchemy import delete as sql_delete, update as sql_update # Use aliases +from sqlalchemy.orm import selectinload +from sqlalchemy import delete as sql_delete, update as sql_update from sqlalchemy.exc import SQLAlchemyError, IntegrityError, OperationalError from typing import Optional, List as PyList from datetime import datetime, timezone -import logging # Add logging import +import logging from sqlalchemy import func -from app.models import Item as ItemModel, User as UserModel # Import UserModel for type hints if needed for selectinload +from app.models import Item as ItemModel, User as UserModel from app.schemas.item import ItemCreate, ItemUpdate from app.core.exceptions import ( ItemNotFoundError, @@ -18,16 +17,15 @@ from app.core.exceptions import ( DatabaseQueryError, DatabaseTransactionError, ConflictError, - ItemOperationError # Add if specific item operation errors are needed + ItemOperationError ) -logger = logging.getLogger(__name__) # Initialize logger +logger = logging.getLogger(__name__) async def create_item(db: AsyncSession, item_in: ItemCreate, list_id: int, user_id: int) -> ItemModel: """Creates a new item record for a specific list, setting its position.""" try: - async with db.begin_nested() if db.in_transaction() else db.begin() as transaction: - # Get the current max position in the list + async with db.begin_nested() if db.in_transaction() else db.begin() as transaction: # Start transaction max_pos_stmt = select(func.max(ItemModel.position)).where(ItemModel.list_id == list_id) max_pos_result = await db.execute(max_pos_stmt) max_pos = max_pos_result.scalar_one_or_none() or 0 @@ -38,26 +36,24 @@ async def create_item(db: AsyncSession, item_in: ItemCreate, list_id: int, user_ list_id=list_id, added_by_id=user_id, is_complete=False, - position=max_pos + 1 # Set the new position + position=max_pos + 1 ) db.add(db_item) - await db.flush() # Assigns ID + await db.flush() - # Re-fetch with relationships stmt = ( select(ItemModel) .where(ItemModel.id == db_item.id) .options( selectinload(ItemModel.added_by_user), - selectinload(ItemModel.completed_by_user) # Will be None but loads relationship + selectinload(ItemModel.completed_by_user) ) ) result = await db.execute(stmt) loaded_item = result.scalar_one_or_none() if loaded_item is None: - # await transaction.rollback() # Redundant, context manager handles rollback on exception - raise ItemOperationError("Failed to load item after creation.") # Define ItemOperationError + raise ItemOperationError("Failed to load item after creation.") return loaded_item except IntegrityError as e: @@ -69,8 +65,6 @@ async def create_item(db: AsyncSession, item_in: ItemCreate, list_id: int, user_ except SQLAlchemyError as e: logger.error(f"Unexpected SQLAlchemy error during item creation: {str(e)}", exc_info=True) raise DatabaseTransactionError(f"Failed to create item: {str(e)}") - # Removed generic Exception block as SQLAlchemyError should cover DB issues, - # and context manager handles rollback. 
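The `async with db.begin_nested() if db.in_transaction() else db.begin()` guard used in `create_item` recurs across these CRUD modules: the helper joins a caller-managed transaction as a savepoint when one is already open, or owns its own transaction otherwise. A minimal sketch of the idea, assuming only an `AsyncSession` named `db`; the helper name and the generic `obj` argument are illustrative, not part of this codebase:

```python
from sqlalchemy.ext.asyncio import AsyncSession


async def save_in_caller_transaction(db: AsyncSession, obj: object) -> None:
    """Persist `obj` inside the caller's transaction, or a fresh one if none is open."""
    # If the caller already opened a transaction, begin_nested() creates a
    # SAVEPOINT, so a failure here rolls back only this unit of work.
    # Otherwise begin() starts (and on success commits) a new transaction.
    async with db.begin_nested() if db.in_transaction() else db.begin():
        db.add(obj)
        await db.flush()  # assigns IDs and runs constraints without committing
```

With this arrangement the session's owner (for example, a request-scoped dependency) keeps control of the final commit, while the CRUD helper only flushes, which is why the comments about removed explicit commits appear throughout the diff.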
async def get_items_by_list_id(db: AsyncSession, list_id: int) -> PyList[ItemModel]: """Gets all items belonging to a specific list, ordered by creation time.""" @@ -100,7 +94,7 @@ async def get_item_by_id(db: AsyncSession, item_id: int) -> Optional[ItemModel]: .options( selectinload(ItemModel.added_by_user), selectinload(ItemModel.completed_by_user), - selectinload(ItemModel.list) # Often useful to get the parent list + selectinload(ItemModel.list) ) ) result = await db.execute(stmt) @@ -113,7 +107,7 @@ async def get_item_by_id(db: AsyncSession, item_id: int) -> Optional[ItemModel]: async def update_item(db: AsyncSession, item_db: ItemModel, item_in: ItemUpdate, user_id: int) -> ItemModel: """Updates an existing item record, checking for version conflicts and handling reordering.""" try: - async with db.begin_nested() if db.in_transaction() else db.begin() as transaction: + async with db.begin_nested() if db.in_transaction() else db.begin() as transaction: # Start transaction if item_db.version != item_in.version: raise ConflictError( f"Item '{item_db.name}' (ID: {item_db.id}) has been modified. " @@ -122,31 +116,23 @@ async def update_item(db: AsyncSession, item_db: ItemModel, item_in: ItemUpdate, update_data = item_in.model_dump(exclude_unset=True, exclude={'version'}) - # --- Handle Reordering --- if 'position' in update_data: - new_position = update_data.pop('position') # Remove from update_data to handle separately + new_position = update_data.pop('position') - # We need the full list to reorder, making sure it's loaded and ordered list_id = item_db.list_id stmt = select(ItemModel).where(ItemModel.list_id == list_id).order_by(ItemModel.position.asc(), ItemModel.created_at.asc()) result = await db.execute(stmt) items_in_list = result.scalars().all() - # Find the item to move item_to_move = next((it for it in items_in_list if it.id == item_db.id), None) if item_to_move: items_in_list.remove(item_to_move) - # Insert at the new position (adjust for 1-based index from frontend) - # Clamp position to be within bounds insert_pos = max(0, min(new_position - 1, len(items_in_list))) items_in_list.insert(insert_pos, item_to_move) - # Re-assign positions for i, item in enumerate(items_in_list): item.position = i + 1 - # --- End Handle Reordering --- - if 'is_complete' in update_data: if update_data['is_complete'] is True: if item_db.completed_by_id is None: @@ -158,10 +144,9 @@ async def update_item(db: AsyncSession, item_db: ItemModel, item_in: ItemUpdate, setattr(item_db, key, value) item_db.version += 1 - db.add(item_db) # Mark as dirty + db.add(item_db) await db.flush() - # Re-fetch with relationships stmt = ( select(ItemModel) .where(ItemModel.id == item_db.id) @@ -174,8 +159,7 @@ async def update_item(db: AsyncSession, item_db: ItemModel, item_in: ItemUpdate, result = await db.execute(stmt) updated_item = result.scalar_one_or_none() - if updated_item is None: # Should not happen - # Rollback will be handled by context manager on raise + if updated_item is None: raise ItemOperationError("Failed to load item after update.") return updated_item @@ -185,7 +169,7 @@ async def update_item(db: AsyncSession, item_db: ItemModel, item_in: ItemUpdate, except OperationalError as e: logger.error(f"Database connection error while updating item: {str(e)}", exc_info=True) raise DatabaseConnectionError(f"Database connection error while updating item: {str(e)}") - except ConflictError: # Re-raise ConflictError, rollback handled by context manager + except ConflictError: raise except SQLAlchemyError as e: 
logger.error(f"Unexpected SQLAlchemy error during item update: {str(e)}", exc_info=True) @@ -196,14 +180,9 @@ async def delete_item(db: AsyncSession, item_db: ItemModel) -> None: try: async with db.begin_nested() if db.in_transaction() else db.begin() as transaction: await db.delete(item_db) - # await transaction.commit() # Removed - # No return needed for None except OperationalError as e: logger.error(f"Database connection error while deleting item: {str(e)}", exc_info=True) raise DatabaseConnectionError(f"Database connection error while deleting item: {str(e)}") except SQLAlchemyError as e: logger.error(f"Unexpected SQLAlchemy error while deleting item: {str(e)}", exc_info=True) - raise DatabaseTransactionError(f"Failed to delete item: {str(e)}") - -# Ensure ItemOperationError is defined in app.core.exceptions if used -# Example: class ItemOperationError(AppException): pass \ No newline at end of file + raise DatabaseTransactionError(f"Failed to delete item: {str(e)}") \ No newline at end of file diff --git a/be/app/crud/list.py b/be/app/crud/list.py index 0aa1dbb..d6814c1 100644 --- a/be/app/crud/list.py +++ b/be/app/crud/list.py @@ -1,11 +1,10 @@ -# app/crud/list.py from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.future import select from sqlalchemy.orm import selectinload, joinedload from sqlalchemy import or_, and_, delete as sql_delete, func as sql_func, desc from sqlalchemy.exc import SQLAlchemyError, IntegrityError, OperationalError from typing import Optional, List as PyList -import logging # Add logging import +import logging from app.schemas.list import ListStatus from app.models import List as ListModel, UserGroup as UserGroupModel, Item as ItemModel @@ -22,12 +21,12 @@ from app.core.exceptions import ( ListOperationError ) -logger = logging.getLogger(__name__) # Initialize logger +logger = logging.getLogger(__name__) async def create_list(db: AsyncSession, list_in: ListCreate, creator_id: int) -> ListModel: """Creates a new list record.""" try: - async with db.begin_nested() if db.in_transaction() else db.begin() as transaction: + async with db.begin_nested() if db.in_transaction() else db.begin() as transaction: # Start transaction db_list = ListModel( name=list_in.name, description=list_in.description, @@ -36,16 +35,14 @@ async def create_list(db: AsyncSession, list_in: ListCreate, creator_id: int) -> is_complete=False ) db.add(db_list) - await db.flush() # Assigns ID + await db.flush() - # Re-fetch with relationships for the response stmt = ( select(ListModel) .where(ListModel.id == db_list.id) .options( selectinload(ListModel.creator), selectinload(ListModel.group) - # selectinload(ListModel.items) # Optionally add if items are always needed in response ) ) result = await db.execute(stmt) @@ -129,7 +126,7 @@ async def update_list(db: AsyncSession, list_db: ListModel, list_in: ListUpdate) """Updates an existing list record, checking for version conflicts.""" try: async with db.begin_nested() if db.in_transaction() else db.begin() as transaction: - if list_db.version != list_in.version: # list_db here is the one passed in, pre-loaded by API layer + if list_db.version != list_in.version: raise ConflictError( f"List '{list_db.name}' (ID: {list_db.id}) has been modified. " f"Your version is {list_in.version}, current version is {list_db.version}. Please refresh." 
@@ -145,20 +142,18 @@ async def update_list(db: AsyncSession, list_db: ListModel, list_in: ListUpdate) db.add(list_db) # Add the already attached list_db to mark it dirty for the session await db.flush() - # Re-fetch with relationships for the response stmt = ( select(ListModel) .where(ListModel.id == list_db.id) .options( selectinload(ListModel.creator), selectinload(ListModel.group) - # selectinload(ListModel.items) # Optionally add if items are always needed in response ) ) result = await db.execute(stmt) updated_list = result.scalar_one_or_none() - if updated_list is None: # Should not happen + if updated_list is None: raise ListOperationError("Failed to load list after update.") return updated_list @@ -177,7 +172,7 @@ async def update_list(db: AsyncSession, list_db: ListModel, list_in: ListUpdate) async def delete_list(db: AsyncSession, list_db: ListModel) -> None: """Deletes a list record. Version check should be done by the caller (API endpoint).""" try: - async with db.begin_nested() if db.in_transaction() else db.begin() as transaction: # Standardize transaction + async with db.begin_nested() if db.in_transaction() else db.begin() as transaction: await db.delete(list_db) except OperationalError as e: logger.error(f"Database connection error while deleting list: {str(e)}", exc_info=True) @@ -257,7 +252,6 @@ async def get_list_by_name_and_group( Used for conflict resolution when creating lists. """ try: - # Base query for the list itself base_query = select(ListModel).where(ListModel.name == name) if group_id is not None: @@ -265,7 +259,6 @@ async def get_list_by_name_and_group( else: base_query = base_query.where(ListModel.group_id.is_(None)) - # Add eager loading for common relationships base_query = base_query.options( selectinload(ListModel.creator), selectinload(ListModel.group) @@ -277,19 +270,17 @@ async def get_list_by_name_and_group( if not target_list: return None - # Permission check is_creator = target_list.created_by_id == user_id if is_creator: return target_list if target_list.group_id: - from app.crud.group import is_user_member # Assuming this is a quick check not needing its own transaction + from app.crud.group import is_user_member is_member_of_group = await is_user_member(db, group_id=target_list.group_id, user_id=user_id) if is_member_of_group: return target_list - # If not creator and (not a group list or not a member of the group list) return None except OperationalError as e: @@ -305,22 +296,17 @@ async def get_lists_statuses_by_ids(db: AsyncSession, list_ids: PyList[int], use if not list_ids: return [] - try: - # First, get the groups the user is a member of + try: group_ids_result = await db.execute( select(UserGroupModel.group_id).where(UserGroupModel.user_id == user_id) ) user_group_ids = group_ids_result.scalars().all() - # Build the permission logic permission_filter = or_( - # User is the creator of the list and_(ListModel.created_by_id == user_id, ListModel.group_id.is_(None)), - # List belongs to a group the user is a member of ListModel.group_id.in_(user_group_ids) ) - # Main query to get list data and item counts query = ( select( ListModel.id, @@ -340,11 +326,7 @@ async def get_lists_statuses_by_ids(db: AsyncSession, list_ids: PyList[int], use result = await db.execute(query) - # The result will be rows of (id, updated_at, item_count). - # We need to verify that all requested list_ids that the user *should* have access to are present. - # The filter in the query already handles permissions. 
- - return result.all() # Returns a list of Row objects with id, updated_at, item_count + return result.all() except OperationalError as e: raise DatabaseConnectionError(f"Failed to connect to database: {str(e)}") diff --git a/be/app/crud/schedule.py b/be/app/crud/schedule.py index a42e0dc..1c6110b 100644 --- a/be/app/crud/schedule.py +++ b/be/app/crud/schedule.py @@ -1,13 +1,10 @@ -# be/app/crud/schedule.py import logging from datetime import date, timedelta from typing import List from itertools import cycle - from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.future import select - -from app.models import Chore, Group, User, ChoreAssignment, UserGroup, ChoreTypeEnum, ChoreHistoryEventTypeEnum +from app.models import Chore, ChoreAssignment, UserGroup, ChoreTypeEnum, ChoreHistoryEventTypeEnum from app.crud.group import get_group_by_id from app.crud.history import create_chore_history_entry from app.core.exceptions import GroupNotFoundError, ChoreOperationError @@ -20,7 +17,7 @@ async def generate_group_chore_schedule( group_id: int, start_date: date, end_date: date, - user_id: int, # The user initiating the action + user_id: int, member_ids: List[int] = None ) -> List[ChoreAssignment]: """ @@ -34,7 +31,6 @@ async def generate_group_chore_schedule( raise GroupNotFoundError(group_id) if not member_ids: - # If no members are specified, use all members from the group members_result = await db.execute( select(UserGroup.user_id).where(UserGroup.group_id == group_id) ) @@ -43,7 +39,6 @@ async def generate_group_chore_schedule( if not member_ids: raise ChoreOperationError("Cannot generate schedule with no members.") - # Fetch all chores belonging to this group chores_result = await db.execute( select(Chore).where(Chore.group_id == group_id, Chore.type == ChoreTypeEnum.group) ) @@ -58,16 +53,7 @@ async def generate_group_chore_schedule( current_date = start_date while current_date <= end_date: for chore in group_chores: - # Check if a chore is due on the current day based on its frequency - # This is a simplified check. A more robust system would use the chore's next_due_date - # and frequency to see if it falls on the current_date. - # For this implementation, we assume we generate assignments for ALL chores on ALL days - # in the range, which might not be desired. - # A better approach is needed here. Let's assume for now we just create assignments for each chore - # on its *next* due date if it falls within the range. 
- if start_date <= chore.next_due_date <= end_date: - # Check if an assignment for this chore on this due date already exists existing_assignment_result = await db.execute( select(ChoreAssignment.id) .where(ChoreAssignment.chore_id == chore.id, ChoreAssignment.due_date == chore.next_due_date) @@ -82,7 +68,7 @@ async def generate_group_chore_schedule( assignment = ChoreAssignment( chore_id=chore.id, assigned_to_user_id=assigned_to_user_id, - due_date=chore.next_due_date, # Assign on the chore's own next_due_date + due_date=chore.next_due_date, is_complete=False ) db.add(assignment) @@ -95,10 +81,9 @@ async def generate_group_chore_schedule( logger.info(f"No new assignments were generated for group {group_id} in the specified date range.") return [] - # Log a single group-level event for the schedule generation await create_chore_history_entry( db, - chore_id=None, # This is a group-level event + chore_id=None, group_id=group_id, changed_by_user_id=user_id, event_type=ChoreHistoryEventTypeEnum.SCHEDULE_GENERATED, @@ -112,8 +97,6 @@ async def generate_group_chore_schedule( await db.flush() - # Refresh assignments to load relationships if needed, although not strictly necessary - # as the objects are already in the session. for assign in new_assignments: await db.refresh(assign) diff --git a/be/app/crud/settlement.py b/be/app/crud/settlement.py index b339b64..f81d4cd 100644 --- a/be/app/crud/settlement.py +++ b/be/app/crud/settlement.py @@ -1,4 +1,3 @@ -# app/crud/settlement.py from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.future import select from sqlalchemy.orm import selectinload, joinedload @@ -7,7 +6,7 @@ from sqlalchemy.exc import SQLAlchemyError, OperationalError, IntegrityError from decimal import Decimal, ROUND_HALF_UP from typing import List as PyList, Optional, Sequence from datetime import datetime, timezone -import logging # Add logging import +import logging from app.models import ( Settlement as SettlementModel, @@ -28,7 +27,7 @@ from app.core.exceptions import ( ConflictError ) -logger = logging.getLogger(__name__) # Initialize logger +logger = logging.getLogger(__name__) async def create_settlement(db: AsyncSession, settlement_in: SettlementCreate, current_user_id: int) -> SettlementModel: """Creates a new settlement record.""" @@ -49,13 +48,6 @@ async def create_settlement(db: AsyncSession, settlement_in: SettlementCreate, c if not group: raise GroupNotFoundError(settlement_in.group_id) - # Permission check example (can be in API layer too) - # if current_user_id not in [payer.id, payee.id]: - # is_member_stmt = select(UserGroupModel.id).where(UserGroupModel.group_id == group.id, UserGroupModel.user_id == current_user_id).limit(1) - # is_member_result = await db.execute(is_member_stmt) - # if not is_member_result.scalar_one_or_none(): - # raise InvalidOperationError("Settlement recorder must be part of the group or one of the parties.") - db_settlement = SettlementModel( group_id=settlement_in.group_id, paid_by_user_id=settlement_in.paid_by_user_id, @@ -68,7 +60,6 @@ async def create_settlement(db: AsyncSession, settlement_in: SettlementCreate, c db.add(db_settlement) await db.flush() - # Re-fetch with relationships stmt = ( select(SettlementModel) .where(SettlementModel.id == db_settlement.id) @@ -87,8 +78,6 @@ async def create_settlement(db: AsyncSession, settlement_in: SettlementCreate, c return loaded_settlement except (UserNotFoundError, GroupNotFoundError, InvalidOperationError) as e: - # These are validation errors, re-raise them. 
- # If a transaction was started, context manager handles rollback. raise except IntegrityError as e: logger.error(f"Database integrity error during settlement creation: {str(e)}", exc_info=True) @@ -115,10 +104,8 @@ async def get_settlement_by_id(db: AsyncSession, settlement_id: int) -> Optional ) return result.scalars().first() except OperationalError as e: - # Optional: logger.warning or info if needed for read operations raise DatabaseConnectionError(f"DB connection error fetching settlement: {str(e)}") except SQLAlchemyError as e: - # Optional: logger.warning or info if needed for read operations raise DatabaseQueryError(f"DB query error fetching settlement: {str(e)}") async def get_settlements_for_group(db: AsyncSession, group_id: int, skip: int = 0, limit: int = 100) -> Sequence[SettlementModel]: @@ -183,10 +170,6 @@ async def update_settlement(db: AsyncSession, settlement_db: SettlementModel, se """ try: async with db.begin_nested() if db.in_transaction() else db.begin(): - # Ensure the settlement_db passed is managed by the current session if not already. - # This is usually true if fetched by an endpoint dependency using the same session. - # If not, `db.add(settlement_db)` might be needed before modification if it's detached. - if not hasattr(settlement_db, 'version') or not hasattr(settlement_in, 'version'): raise InvalidOperationError("Version field is missing in model or input for optimistic locking.") @@ -204,22 +187,14 @@ async def update_settlement(db: AsyncSession, settlement_db: SettlementModel, se if field in allowed_to_update: setattr(settlement_db, field, value) updated_something = True - # Silently ignore fields not allowed to update or raise error: - # else: - # raise InvalidOperationError(f"Field '{field}' cannot be updated.") if not updated_something and not settlement_in.model_fields_set.intersection(allowed_to_update): - # No updatable fields were actually provided, or they didn't change - # Still, we might want to return the re-loaded settlement if version matched. pass settlement_db.version += 1 - settlement_db.updated_at = datetime.now(timezone.utc) # Ensure model has this field - - db.add(settlement_db) # Mark as dirty + settlement_db.updated_at = datetime.now(timezone.utc) await db.flush() - # Re-fetch with relationships stmt = ( select(SettlementModel) .where(SettlementModel.id == settlement_db.id) @@ -233,11 +208,11 @@ async def update_settlement(db: AsyncSession, settlement_db: SettlementModel, se result = await db.execute(stmt) updated_settlement = result.scalar_one_or_none() - if updated_settlement is None: # Should not happen + if updated_settlement is None: raise SettlementOperationError("Failed to load settlement after update.") return updated_settlement - except ConflictError as e: # ConflictError should be defined in exceptions + except ConflictError as e: raise except InvalidOperationError as e: raise @@ -261,13 +236,13 @@ async def delete_settlement(db: AsyncSession, settlement_db: SettlementModel, ex async with db.begin_nested() if db.in_transaction() else db.begin(): if expected_version is not None: if not hasattr(settlement_db, 'version') or settlement_db.version != expected_version: - raise ConflictError( # Make sure ConflictError is defined + raise ConflictError( f"Settlement (ID: {settlement_db.id}) cannot be deleted. " f"Expected version {expected_version} does not match current version {settlement_db.version}. Please refresh." 
) await db.delete(settlement_db) - except ConflictError as e: # ConflictError should be defined + except ConflictError as e: raise except OperationalError as e: logger.error(f"Database connection error during settlement deletion: {str(e)}", exc_info=True) @@ -275,7 +250,3 @@ async def delete_settlement(db: AsyncSession, settlement_db: SettlementModel, ex except SQLAlchemyError as e: logger.error(f"Unexpected SQLAlchemy error during settlement deletion: {str(e)}", exc_info=True) raise DatabaseTransactionError(f"DB transaction error during settlement deletion: {str(e)}") - -# Ensure SettlementOperationError and ConflictError are defined in app.core.exceptions -# Example: class SettlementOperationError(AppException): pass -# Example: class ConflictError(AppException): status_code = 409 \ No newline at end of file diff --git a/be/app/crud/settlement_activity.py b/be/app/crud/settlement_activity.py index 753a767..6ea3ccf 100644 --- a/be/app/crud/settlement_activity.py +++ b/be/app/crud/settlement_activity.py @@ -14,9 +14,7 @@ from app.models import ( ExpenseSplitStatusEnum, ExpenseOverallStatusEnum, ) -# Placeholder for Pydantic schema - actual schema definition is a later step -# from app.schemas.settlement_activity import SettlementActivityCreate # Assuming this path -from pydantic import BaseModel # Using pydantic BaseModel directly for the placeholder +from pydantic import BaseModel class SettlementActivityCreatePlaceholder(BaseModel): @@ -26,8 +24,7 @@ class SettlementActivityCreatePlaceholder(BaseModel): paid_at: Optional[datetime] = None class Config: - orm_mode = True # Pydantic V1 style orm_mode - # from_attributes = True # Pydantic V2 style + orm_mode = True async def update_expense_split_status(db: AsyncSession, expense_split_id: int) -> Optional[ExpenseSplit]: @@ -35,7 +32,6 @@ async def update_expense_split_status(db: AsyncSession, expense_split_id: int) - Updates the status of an ExpenseSplit based on its settlement activities. Also updates the overall status of the parent Expense. 
""" - # Fetch the ExpenseSplit with its related settlement_activities and the parent expense result = await db.execute( select(ExpenseSplit) .options( @@ -47,18 +43,13 @@ async def update_expense_split_status(db: AsyncSession, expense_split_id: int) - expense_split = result.scalar_one_or_none() if not expense_split: - # Or raise an exception, depending on desired error handling return None - # Calculate total_paid from all settlement_activities for that split total_paid = sum(activity.amount_paid for activity in expense_split.settlement_activities) - total_paid = Decimal(total_paid).quantize(Decimal("0.01")) # Ensure two decimal places + total_paid = Decimal(total_paid).quantize(Decimal("0.01")) - # Compare total_paid with ExpenseSplit.owed_amount if total_paid >= expense_split.owed_amount: expense_split.status = ExpenseSplitStatusEnum.paid - # Set paid_at to the latest relevant SettlementActivity or current time - # For simplicity, let's find the latest paid_at from activities, or use now() latest_paid_at = None if expense_split.settlement_activities: latest_paid_at = max(act.paid_at for act in expense_split.settlement_activities if act.paid_at) @@ -66,13 +57,13 @@ async def update_expense_split_status(db: AsyncSession, expense_split_id: int) - expense_split.paid_at = latest_paid_at if latest_paid_at else datetime.now(timezone.utc) elif total_paid > 0: expense_split.status = ExpenseSplitStatusEnum.partially_paid - expense_split.paid_at = None # Clear paid_at if not fully paid + expense_split.paid_at = None else: # total_paid == 0 expense_split.status = ExpenseSplitStatusEnum.unpaid - expense_split.paid_at = None # Clear paid_at + expense_split.paid_at = None await db.flush() - await db.refresh(expense_split, attribute_names=['status', 'paid_at', 'expense']) # Refresh to get updated data and related expense + await db.refresh(expense_split, attribute_names=['status', 'paid_at', 'expense']) return expense_split @@ -81,18 +72,16 @@ async def update_expense_overall_status(db: AsyncSession, expense_id: int) -> Op """ Updates the overall_status of an Expense based on the status of its splits. """ - # Fetch the Expense with its related splits result = await db.execute( select(Expense).options(selectinload(Expense.splits)).where(Expense.id == expense_id) ) expense = result.scalar_one_or_none() if not expense: - # Or raise an exception return None - if not expense.splits: # No splits, should not happen for a valid expense but handle defensively - expense.overall_settlement_status = ExpenseOverallStatusEnum.unpaid # Or some other default/error state + if not expense.splits: + expense.overall_settlement_status = ExpenseOverallStatusEnum.unpaid await db.flush() await db.refresh(expense) return expense @@ -107,14 +96,14 @@ async def update_expense_overall_status(db: AsyncSession, expense_id: int) -> Op num_paid_splits += 1 elif split.status == ExpenseSplitStatusEnum.partially_paid: num_partially_paid_splits += 1 - else: # unpaid + else: num_unpaid_splits += 1 if num_paid_splits == num_splits: expense.overall_settlement_status = ExpenseOverallStatusEnum.paid elif num_unpaid_splits == num_splits: expense.overall_settlement_status = ExpenseOverallStatusEnum.unpaid - else: # Mix of paid, partially_paid, or unpaid but not all unpaid/paid + else: expense.overall_settlement_status = ExpenseOverallStatusEnum.partially_paid await db.flush() @@ -130,43 +119,33 @@ async def create_settlement_activity( """ Creates a new settlement activity, then updates the parent expense split and expense statuses. 
""" - # Validate ExpenseSplit split_result = await db.execute(select(ExpenseSplit).where(ExpenseSplit.id == settlement_activity_in.expense_split_id)) expense_split = split_result.scalar_one_or_none() if not expense_split: - # Consider raising an HTTPException in an API layer - return None # ExpenseSplit not found + return None - # Validate User (paid_by_user_id) user_result = await db.execute(select(User).where(User.id == settlement_activity_in.paid_by_user_id)) paid_by_user = user_result.scalar_one_or_none() if not paid_by_user: return None # User not found - # Create SettlementActivity instance db_settlement_activity = SettlementActivity( expense_split_id=settlement_activity_in.expense_split_id, paid_by_user_id=settlement_activity_in.paid_by_user_id, amount_paid=settlement_activity_in.amount_paid, paid_at=settlement_activity_in.paid_at if settlement_activity_in.paid_at else datetime.now(timezone.utc), - created_by_user_id=current_user_id # The user recording the activity + created_by_user_id=current_user_id ) db.add(db_settlement_activity) - await db.flush() # Flush to get the ID for db_settlement_activity + await db.flush() # Update statuses updated_split = await update_expense_split_status(db, expense_split_id=db_settlement_activity.expense_split_id) if updated_split and updated_split.expense_id: await update_expense_overall_status(db, expense_id=updated_split.expense_id) else: - # This case implies update_expense_split_status returned None or expense_id was missing. - # This could be a problem, consider logging or raising an error. - # For now, the transaction would roll back if an exception is raised. - # If not raising, the overall status update might be skipped. - pass # Or handle error - - await db.refresh(db_settlement_activity, attribute_names=['split', 'payer', 'creator']) # Refresh to load relationships + pass return db_settlement_activity @@ -180,9 +159,9 @@ async def get_settlement_activity_by_id( result = await db.execute( select(SettlementActivity) .options( - selectinload(SettlementActivity.split).selectinload(ExpenseSplit.expense), # Load split and its parent expense - selectinload(SettlementActivity.payer), # Load the user who paid - selectinload(SettlementActivity.creator) # Load the user who created the record + selectinload(SettlementActivity.split).selectinload(ExpenseSplit.expense), + selectinload(SettlementActivity.payer), + selectinload(SettlementActivity.creator) ) .where(SettlementActivity.id == settlement_activity_id) ) @@ -199,8 +178,8 @@ async def get_settlement_activities_for_split( select(SettlementActivity) .where(SettlementActivity.expense_split_id == expense_split_id) .options( - selectinload(SettlementActivity.payer), # Load the user who paid - selectinload(SettlementActivity.creator) # Load the user who created the record + selectinload(SettlementActivity.payer), + selectinload(SettlementActivity.creator) ) .order_by(SettlementActivity.paid_at.desc(), SettlementActivity.created_at.desc()) .offset(skip) diff --git a/be/app/crud/user.py b/be/app/crud/user.py index 0e8b95e..545bfea 100644 --- a/be/app/crud/user.py +++ b/be/app/crud/user.py @@ -1,12 +1,11 @@ -# app/crud/user.py from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.future import select -from sqlalchemy.orm import selectinload # Ensure selectinload is imported +from sqlalchemy.orm import selectinload from sqlalchemy.exc import SQLAlchemyError, IntegrityError, OperationalError from typing import Optional -import logging # Add logging import +import logging -from app.models import 
User as UserModel, UserGroup as UserGroupModel, Group as GroupModel # Import related models for selectinload +from app.models import User as UserModel, UserGroup as UserGroupModel from app.schemas.user import UserCreate from app.core.security import hash_password from app.core.exceptions import ( @@ -16,23 +15,19 @@ from app.core.exceptions import ( DatabaseIntegrityError, DatabaseQueryError, DatabaseTransactionError, - UserOperationError # Add if specific user operation errors are needed + UserOperationError ) -logger = logging.getLogger(__name__) # Initialize logger +logger = logging.getLogger(__name__) async def get_user_by_email(db: AsyncSession, email: str) -> Optional[UserModel]: """Fetches a user from the database by email, with common relationships.""" try: - # db.begin() is not strictly necessary for a single read, but ensures atomicity if multiple reads were added. - # For a single select, it can be omitted if preferred, session handles connection. stmt = ( select(UserModel) .filter(UserModel.email == email) .options( - selectinload(UserModel.group_associations).selectinload(UserGroupModel.group), # Groups user is member of - selectinload(UserModel.created_groups) # Groups user created - # Add other relationships as needed by UserPublic schema + selectinload(UserModel.group_associations).selectinload(UserGroupModel.group), ) ) result = await db.execute(stmt) @@ -51,27 +46,25 @@ async def create_user(db: AsyncSession, user_in: UserCreate) -> UserModel: _hashed_password = hash_password(user_in.password) db_user = UserModel( email=user_in.email, - hashed_password=_hashed_password, # Field name in model is hashed_password + hashed_password=_hashed_password, name=user_in.name ) db.add(db_user) - await db.flush() # Flush to get DB-generated values like ID + await db.flush() - # Re-fetch with relationships stmt = ( select(UserModel) .where(UserModel.id == db_user.id) .options( selectinload(UserModel.group_associations).selectinload(UserGroupModel.group), selectinload(UserModel.created_groups) - # Add other relationships as needed by UserPublic schema ) ) result = await db.execute(stmt) loaded_user = result.scalar_one_or_none() if loaded_user is None: - raise UserOperationError("Failed to load user after creation.") # Define UserOperationError + raise UserOperationError("Failed to load user after creation.") return loaded_user except IntegrityError as e: @@ -84,7 +77,4 @@ async def create_user(db: AsyncSession, user_in: UserCreate) -> UserModel: raise DatabaseConnectionError(f"Database connection error during user creation: {str(e)}") except SQLAlchemyError as e: logger.error(f"Unexpected SQLAlchemy error during user creation for email '{user_in.email}': {str(e)}", exc_info=True) - raise DatabaseTransactionError(f"Failed to create user due to other DB error: {str(e)}") - -# Ensure UserOperationError is defined in app.core.exceptions if used -# Example: class UserOperationError(AppException): pass \ No newline at end of file + raise DatabaseTransactionError(f"Failed to create user due to other DB error: {str(e)}") \ No newline at end of file diff --git a/be/app/database.py b/be/app/database.py index 9fc9105..2af3033 100644 --- a/be/app/database.py +++ b/be/app/database.py @@ -1,24 +1,18 @@ -# app/database.py from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession from sqlalchemy.orm import sessionmaker, declarative_base from app.config import settings -# Ensure DATABASE_URL is set before proceeding if not settings.DATABASE_URL: raise ValueError("DATABASE_URL is not configured 
in settings.") -# Create the SQLAlchemy async engine -# pool_recycle=3600 helps prevent stale connections on some DBs engine = create_async_engine( settings.DATABASE_URL, - echo=False, # Disable SQL query logging for production (use DEBUG log level to enable) - future=True, # Use SQLAlchemy 2.0 style features - pool_recycle=3600, # Optional: recycle connections after 1 hour - pool_pre_ping=True # Add this line to ensure connections are live + echo=False, + future=True, + pool_recycle=3600, + pool_pre_ping=True ) -# Create a configured "Session" class -# expire_on_commit=False prevents attributes from expiring after commit AsyncSessionLocal = sessionmaker( bind=engine, class_=AsyncSession, @@ -27,10 +21,8 @@ AsyncSessionLocal = sessionmaker( autocommit=False, ) -# Base class for our ORM models Base = declarative_base() -# Dependency to get DB session in path operations async def get_session() -> AsyncSession: # type: ignore """ Dependency function that yields an AsyncSession for read-only operations. @@ -38,7 +30,6 @@ async def get_session() -> AsyncSession: # type: ignore """ async with AsyncSessionLocal() as session: yield session - # The 'async with' block handles session.close() automatically. async def get_transactional_session() -> AsyncSession: # type: ignore """ @@ -51,7 +42,5 @@ async def get_transactional_session() -> AsyncSession: # type: ignore async with AsyncSessionLocal() as session: async with session.begin(): yield session - # Transaction is automatically committed on success or rolled back on exception -# Alias for backward compatibility get_db = get_session \ No newline at end of file diff --git a/be/app/db/session.py b/be/app/db/session.py index 959b962..2232e8c 100644 --- a/be/app/db/session.py +++ b/be/app/db/session.py @@ -1,4 +1,2 @@ from app.database import AsyncSessionLocal - -# Export the async session factory async_session = AsyncSessionLocal \ No newline at end of file diff --git a/be/app/jobs/recurring_expenses.py b/be/app/jobs/recurring_expenses.py index 96f9026..c0328a9 100644 --- a/be/app/jobs/recurring_expenses.py +++ b/be/app/jobs/recurring_expenses.py @@ -15,18 +15,15 @@ async def generate_recurring_expenses(db: AsyncSession) -> None: Should be run daily to check for and create new recurring expenses. 
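For clarity on the two session dependencies defined in database.py above: get_transactional_session wraps the handler in session.begin() (commit on success, rollback on exception), while get_session is the plain read path. A hedged usage sketch; the router, endpoint paths and function names below are illustrative only.

from fastapi import APIRouter, Depends
from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncSession

from app.database import get_session, get_transactional_session

example_router = APIRouter()

@example_router.post("/example-write")
async def example_write(db: AsyncSession = Depends(get_transactional_session)):
    # Work done here is committed by the dependency's session.begin() block;
    # raising an exception rolls the whole request back.
    await db.execute(text("SELECT 1"))
    return {"status": "ok"}

@example_router.get("/example-read")
async def example_read(db: AsyncSession = Depends(get_session)):
    # Read-only path: no explicit transaction management needed.
    result = await db.execute(text("SELECT 1"))
    return {"value": result.scalar_one()}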
""" try: - # Get all active recurring expenses that need to be generated now = datetime.utcnow() query = select(Expense).join(RecurrencePattern).where( and_( Expense.is_recurring == True, Expense.next_occurrence <= now, - # Check if we haven't reached max occurrences ( (RecurrencePattern.max_occurrences == None) | (RecurrencePattern.max_occurrences > 0) ), - # Check if we haven't reached end date ( (RecurrencePattern.end_date == None) | (RecurrencePattern.end_date > now) @@ -54,12 +51,10 @@ async def _generate_next_occurrence(db: AsyncSession, expense: Expense) -> None: if not pattern: return - # Calculate next occurrence date next_date = _calculate_next_occurrence(expense.next_occurrence, pattern) if not next_date: return - # Create new expense based on template new_expense = ExpenseCreate( description=expense.description, total_amount=expense.total_amount, @@ -70,14 +65,12 @@ async def _generate_next_occurrence(db: AsyncSession, expense: Expense) -> None: group_id=expense.group_id, item_id=expense.item_id, paid_by_user_id=expense.paid_by_user_id, - is_recurring=False, # Generated expenses are not recurring - splits_in=None # Will be generated based on split_type + is_recurring=False, + splits_in=None ) - # Create the new expense created_expense = await create_expense(db, new_expense, expense.created_by_user_id) - # Update the original expense expense.last_occurrence = next_date expense.next_occurrence = _calculate_next_occurrence(next_date, pattern) @@ -98,7 +91,6 @@ def _calculate_next_occurrence(current_date: datetime, pattern: RecurrencePatter if not pattern.days_of_week: return current_date + timedelta(weeks=pattern.interval) - # Find next day of week current_weekday = current_date.weekday() next_weekday = min((d for d in pattern.days_of_week if d > current_weekday), default=min(pattern.days_of_week)) @@ -108,7 +100,6 @@ def _calculate_next_occurrence(current_date: datetime, pattern: RecurrencePatter return current_date + timedelta(days=days_ahead) elif pattern.type == 'monthly': - # Add months to current date year = current_date.year + (current_date.month + pattern.interval - 1) // 12 month = (current_date.month + pattern.interval - 1) % 12 + 1 return current_date.replace(year=year, month=month) diff --git a/be/app/main.py b/be/app/main.py index 139ff6f..5d88623 100644 --- a/be/app/main.py +++ b/be/app/main.py @@ -1,60 +1,36 @@ -# app/main.py import logging -import uvicorn -from fastapi import FastAPI, HTTPException, Depends, status, Request +from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware from starlette.middleware.sessions import SessionMiddleware import sentry_sdk from sentry_sdk.integrations.fastapi import FastApiIntegration -from fastapi_users.authentication import JWTStrategy -from pydantic import BaseModel -from jose import jwt, JWTError -from sqlalchemy.ext.asyncio import AsyncEngine -from alembic.config import Config -from alembic import command import os import sys - from app.api.api_router import api_router from app.config import settings from app.core.api_config import API_METADATA, API_TAGS -from app.auth import fastapi_users, auth_backend, get_refresh_jwt_strategy, get_jwt_strategy -from app.models import User -from app.api.auth.oauth import router as oauth_router +from app.auth import fastapi_users, auth_backend from app.schemas.user import UserPublic, UserCreate, UserUpdate from app.core.scheduler import init_scheduler, shutdown_scheduler -from app.database import get_session -from sqlalchemy import select -# Response model for refresh 
endpoint -class RefreshResponse(BaseModel): - access_token: str - refresh_token: str - token_type: str = "bearer" - -# Initialize Sentry only if DSN is provided if settings.SENTRY_DSN: sentry_sdk.init( dsn=settings.SENTRY_DSN, integrations=[ FastApiIntegration(), ], - # Adjust traces_sample_rate for production traces_sample_rate=0.1 if settings.is_production else 1.0, environment=settings.ENVIRONMENT, - # Enable PII data only in development send_default_pii=not settings.is_production ) -# --- Logging Setup --- logging.basicConfig( level=getattr(logging, settings.LOG_LEVEL), format=settings.LOG_FORMAT ) logger = logging.getLogger(__name__) -# --- FastAPI App Instance --- -# Create API metadata with environment-dependent settings + api_metadata = { **API_METADATA, "docs_url": settings.docs_url, @@ -67,13 +43,11 @@ app = FastAPI( openapi_tags=API_TAGS ) -# Add session middleware for OAuth app.add_middleware( SessionMiddleware, secret_key=settings.SESSION_SECRET_KEY ) -# --- CORS Middleware --- app.add_middleware( CORSMiddleware, allow_origins=settings.cors_origins_list, @@ -82,82 +56,7 @@ app.add_middleware( allow_headers=["*"], expose_headers=["*"] ) -# --- End CORS Middleware --- -# Refresh token endpoint -@app.post("/auth/jwt/refresh", response_model=RefreshResponse, tags=["auth"]) -async def refresh_jwt_token( - request: Request, - refresh_strategy: JWTStrategy = Depends(get_refresh_jwt_strategy), - access_strategy: JWTStrategy = Depends(get_jwt_strategy), -): - """ - Refresh access token using a valid refresh token. - Send refresh token in Authorization header: Bearer - """ - try: - # Get refresh token from Authorization header - authorization = request.headers.get("Authorization") - if not authorization or not authorization.startswith("Bearer "): - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="Refresh token missing or invalid format", - headers={"WWW-Authenticate": "Bearer"}, - ) - - refresh_token = authorization.split(" ")[1] - - # Validate refresh token and get user data - try: - # Decode the refresh token to get the user identifier - payload = jwt.decode(refresh_token, settings.SECRET_KEY, algorithms=["HS256"]) - user_id = payload.get("sub") - if user_id is None: - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="Invalid refresh token", - ) - except JWTError: - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="Invalid refresh token", - ) - - # Get user from database - async with get_session() as session: - result = await session.execute(select(User).where(User.id == int(user_id))) - user = result.scalar_one_or_none() - - if not user or not user.is_active: - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="User not found or inactive", - ) - - # Generate new tokens - new_access_token = await access_strategy.write_token(user) - new_refresh_token = await refresh_strategy.write_token(user) - - return RefreshResponse( - access_token=new_access_token, - refresh_token=new_refresh_token, - token_type="bearer" - ) - - except HTTPException: - raise - except Exception as e: - logger.error(f"Error refreshing token: {e}") - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="Invalid refresh token" - ) - -# --- Include API Routers --- -# Include OAuth routes first (no auth required) -app.include_router(oauth_router, prefix="/auth", tags=["auth"]) - -# Include FastAPI-Users routes app.include_router( fastapi_users.get_auth_router(auth_backend), prefix="/auth/jwt", @@ 
-184,11 +83,8 @@ app.include_router( tags=["users"], ) -# Include your API router app.include_router(api_router, prefix=settings.API_PREFIX) -# --- End Include API Routers --- -# Health check endpoint @app.get("/health", tags=["Health"]) async def health_check(): """ @@ -200,7 +96,6 @@ async def health_check(): "version": settings.API_VERSION } -# --- Root Endpoint (Optional - outside the main API structure) --- @app.get("/", tags=["Root"]) async def read_root(): """ @@ -213,21 +108,17 @@ async def read_root(): "environment": settings.ENVIRONMENT, "version": settings.API_VERSION } -# --- End Root Endpoint --- async def run_migrations(): """Run database migrations.""" try: logger.info("Running database migrations...") - # Get the absolute path to the alembic directory base_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) alembic_path = os.path.join(base_path, 'alembic') - # Add alembic directory to Python path if alembic_path not in sys.path: sys.path.insert(0, alembic_path) - # Import and run migrations from migrations import run_migrations as run_db_migrations await run_db_migrations() @@ -240,11 +131,7 @@ async def run_migrations(): async def startup_event(): """Initialize services on startup.""" logger.info(f"Application startup in {settings.ENVIRONMENT} environment...") - - # Run database migrations # await run_migrations() - - # Initialize scheduler init_scheduler() logger.info("Application startup complete.") @@ -252,15 +139,5 @@ async def startup_event(): async def shutdown_event(): """Cleanup services on shutdown.""" logger.info("Application shutdown: Disconnecting from database...") - # await database.engine.dispose() # Close connection pool shutdown_scheduler() - logger.info("Application shutdown complete.") -# --- End Events --- - - -# --- Direct Run (for simple local testing if needed) --- -# It's better to use `uvicorn app.main:app --reload` from the terminal -# if __name__ == "__main__": -# logger.info("Starting Uvicorn server directly from main.py") -# uvicorn.run(app, host="0.0.0.0", port=8000) -# ------------------------------------------------------ \ No newline at end of file + logger.info("Application shutdown complete.") \ No newline at end of file diff --git a/be/app/models.py b/be/app/models.py index 6a4f513..31cf02d 100644 --- a/be/app/models.py +++ b/be/app/models.py @@ -1,4 +1,3 @@ -# app/models.py import enum import secrets from datetime import datetime, timedelta, timezone @@ -14,16 +13,14 @@ from sqlalchemy import ( UniqueConstraint, Index, DDL, - event, - delete, func, text as sa_text, - Text, # <-- Add Text for description - Numeric, # <-- Add Numeric for price + Text, + Numeric, CheckConstraint, - Date # Added Date for Chore model + Date ) -from sqlalchemy.orm import relationship, backref +from sqlalchemy.orm import relationship from sqlalchemy.dialects.postgresql import JSONB from .database import Base @@ -82,7 +79,6 @@ class ChoreHistoryEventTypeEnum(str, enum.Enum): UNASSIGNED = "unassigned" REASSIGNED = "reassigned" SCHEDULE_GENERATED = "schedule_generated" - # Add more specific events as needed DUE_DATE_CHANGED = "due_date_changed" DETAILS_CHANGED = "details_changed" @@ -103,34 +99,20 @@ class User(Base): created_groups = relationship("Group", back_populates="creator") group_associations = relationship("UserGroup", back_populates="user", cascade="all, delete-orphan") created_invites = relationship("Invite", back_populates="creator") - - # --- NEW Relationships for Lists/Items --- - created_lists = relationship("List", 
foreign_keys="List.created_by_id", back_populates="creator") # Link List.created_by_id -> User - added_items = relationship("Item", foreign_keys="Item.added_by_id", back_populates="added_by_user") # Link Item.added_by_id -> User - completed_items = relationship("Item", foreign_keys="Item.completed_by_id", back_populates="completed_by_user") # Link Item.completed_by_id -> User - # --- End NEW Relationships --- - - # --- Relationships for Cost Splitting --- + created_lists = relationship("List", foreign_keys="List.created_by_id", back_populates="creator") + added_items = relationship("Item", foreign_keys="Item.added_by_id", back_populates="added_by_user") + completed_items = relationship("Item", foreign_keys="Item.completed_by_id", back_populates="completed_by_user") expenses_paid = relationship("Expense", foreign_keys="Expense.paid_by_user_id", back_populates="paid_by_user", cascade="all, delete-orphan") expenses_created = relationship("Expense", foreign_keys="Expense.created_by_user_id", back_populates="created_by_user", cascade="all, delete-orphan") expense_splits = relationship("ExpenseSplit", foreign_keys="ExpenseSplit.user_id", back_populates="user", cascade="all, delete-orphan") settlements_made = relationship("Settlement", foreign_keys="Settlement.paid_by_user_id", back_populates="payer", cascade="all, delete-orphan") settlements_received = relationship("Settlement", foreign_keys="Settlement.paid_to_user_id", back_populates="payee", cascade="all, delete-orphan") settlements_created = relationship("Settlement", foreign_keys="Settlement.created_by_user_id", back_populates="created_by_user", cascade="all, delete-orphan") - # --- End Relationships for Cost Splitting --- - - # --- Relationships for Chores --- created_chores = relationship("Chore", foreign_keys="[Chore.created_by_id]", back_populates="creator") assigned_chores = relationship("ChoreAssignment", back_populates="assigned_user", cascade="all, delete-orphan") - # --- End Relationships for Chores --- - - # --- History Relationships --- chore_history_entries = relationship("ChoreHistory", back_populates="changed_by_user", cascade="all, delete-orphan") assignment_history_entries = relationship("ChoreAssignmentHistory", back_populates="changed_by_user", cascade="all, delete-orphan") - # --- End History Relationships --- - -# --- Group Model --- class Group(Base): __tablename__ = "groups" @@ -139,30 +121,16 @@ class Group(Base): created_by_id = Column(Integer, ForeignKey("users.id"), nullable=False) created_at = Column(DateTime(timezone=True), server_default=func.now(), nullable=False) - # --- Relationships --- creator = relationship("User", back_populates="created_groups") member_associations = relationship("UserGroup", back_populates="group", cascade="all, delete-orphan") invites = relationship("Invite", back_populates="group", cascade="all, delete-orphan") - # --- NEW Relationship for Lists --- - lists = relationship("List", back_populates="group", cascade="all, delete-orphan") # Link List.group_id -> Group - # --- End NEW Relationship --- - - # --- Relationships for Cost Splitting --- + lists = relationship("List", back_populates="group", cascade="all, delete-orphan") expenses = relationship("Expense", foreign_keys="Expense.group_id", back_populates="group", cascade="all, delete-orphan") settlements = relationship("Settlement", foreign_keys="Settlement.group_id", back_populates="group", cascade="all, delete-orphan") - # --- End Relationships for Cost Splitting --- - - # --- Relationship for Chores --- chores = 
relationship("Chore", back_populates="group", cascade="all, delete-orphan") - # --- End Relationship for Chores --- - - # --- History Relationships --- chore_history = relationship("ChoreHistory", back_populates="group", cascade="all, delete-orphan") - # --- End History Relationships --- - -# --- UserGroup Association Model --- class UserGroup(Base): __tablename__ = "user_groups" __table_args__ = (UniqueConstraint('user_id', 'group_id', name='uq_user_group'),) @@ -176,8 +144,6 @@ class UserGroup(Base): user = relationship("User", back_populates="group_associations") group = relationship("Group", back_populates="member_associations") - -# --- Invite Model --- class Invite(Base): __tablename__ = "invites" __table_args__ = ( @@ -196,36 +162,30 @@ class Invite(Base): creator = relationship("User", back_populates="created_invites") -# === NEW: List Model === class List(Base): __tablename__ = "lists" id = Column(Integer, primary_key=True, index=True) name = Column(String, index=True, nullable=False) description = Column(Text, nullable=True) - created_by_id = Column(Integer, ForeignKey("users.id"), nullable=False) # Who created this list - group_id = Column(Integer, ForeignKey("groups.id"), nullable=True) # Which group it belongs to (NULL if personal) + created_by_id = Column(Integer, ForeignKey("users.id"), nullable=False) + group_id = Column(Integer, ForeignKey("groups.id"), nullable=True) is_complete = Column(Boolean, default=False, nullable=False) created_at = Column(DateTime(timezone=True), server_default=func.now(), nullable=False) updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now(), nullable=False) version = Column(Integer, nullable=False, default=1, server_default='1') - # --- Relationships --- - creator = relationship("User", back_populates="created_lists") # Link to User.created_lists - group = relationship("Group", back_populates="lists") # Link to Group.lists + creator = relationship("User", back_populates="created_lists") + group = relationship("Group", back_populates="lists") items = relationship( "Item", back_populates="list", cascade="all, delete-orphan", - order_by="Item.position.asc(), Item.created_at.asc()" # Default order by position, then creation + order_by="Item.position.asc(), Item.created_at.asc()" ) - # --- Relationships for Cost Splitting --- expenses = relationship("Expense", foreign_keys="Expense.list_id", back_populates="list", cascade="all, delete-orphan") - # --- End Relationships for Cost Splitting --- - -# === NEW: Item Model === class Item(Base): __tablename__ = "items" __table_args__ = ( @@ -233,31 +193,24 @@ class Item(Base): ) id = Column(Integer, primary_key=True, index=True) - list_id = Column(Integer, ForeignKey("lists.id", ondelete="CASCADE"), nullable=False) # Belongs to which list + list_id = Column(Integer, ForeignKey("lists.id", ondelete="CASCADE"), nullable=False) name = Column(String, index=True, nullable=False) - quantity = Column(String, nullable=True) # Flexible quantity (e.g., "1", "2 lbs", "a bunch") + quantity = Column(String, nullable=True) is_complete = Column(Boolean, default=False, nullable=False) - price = Column(Numeric(10, 2), nullable=True) # For cost splitting later (e.g., 12345678.99) - position = Column(Integer, nullable=False, server_default='0') # For ordering - added_by_id = Column(Integer, ForeignKey("users.id"), nullable=False) # Who added this item - completed_by_id = Column(Integer, ForeignKey("users.id"), nullable=True) # Who marked it complete + price = Column(Numeric(10, 2), 
nullable=True) + position = Column(Integer, nullable=False, server_default='0') + added_by_id = Column(Integer, ForeignKey("users.id"), nullable=False) + completed_by_id = Column(Integer, ForeignKey("users.id"), nullable=True) created_at = Column(DateTime(timezone=True), server_default=func.now(), nullable=False) updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now(), nullable=False) version = Column(Integer, nullable=False, default=1, server_default='1') # --- Relationships --- - list = relationship("List", back_populates="items") # Link to List.items - added_by_user = relationship("User", foreign_keys=[added_by_id], back_populates="added_items") # Link to User.added_items - completed_by_user = relationship("User", foreign_keys=[completed_by_id], back_populates="completed_items") # Link to User.completed_items - - # --- Relationships for Cost Splitting --- - # If an item directly results in an expense, or an expense can be tied to an item. - expenses = relationship("Expense", back_populates="item") # An item might have multiple associated expenses - # --- End Relationships for Cost Splitting --- - - -# === NEW Models for Advanced Cost Splitting === - + list = relationship("List", back_populates="items") + added_by_user = relationship("User", foreign_keys=[added_by_id], back_populates="added_items") + completed_by_user = relationship("User", foreign_keys=[completed_by_id], back_populates="completed_items") + expenses = relationship("Expense", back_populates="item") + class Expense(Base): __tablename__ = "expenses" @@ -268,7 +221,6 @@ class Expense(Base): expense_date = Column(DateTime(timezone=True), server_default=func.now(), nullable=False) split_type = Column(SAEnum(SplitTypeEnum, name="splittypeenum", create_type=True), nullable=False) - # Foreign Keys list_id = Column(Integer, ForeignKey("lists.id"), nullable=True, index=True) group_id = Column(Integer, ForeignKey("groups.id"), nullable=True, index=True) item_id = Column(Integer, ForeignKey("items.id"), nullable=True) @@ -279,7 +231,6 @@ class Expense(Base): updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now(), nullable=False) version = Column(Integer, nullable=False, default=1, server_default='1') - # Relationships paid_by_user = relationship("User", foreign_keys=[paid_by_user_id], back_populates="expenses_paid") created_by_user = relationship("User", foreign_keys=[created_by_user_id], back_populates="expenses_created") list = relationship("List", foreign_keys=[list_id], back_populates="expenses") @@ -289,7 +240,6 @@ class Expense(Base): parent_expense = relationship("Expense", remote_side=[id], back_populates="child_expenses") child_expenses = relationship("Expense", back_populates="parent_expense") overall_settlement_status = Column(SAEnum(ExpenseOverallStatusEnum, name="expenseoverallstatusenum", create_type=True), nullable=False, server_default=ExpenseOverallStatusEnum.unpaid.value, default=ExpenseOverallStatusEnum.unpaid) - # --- Recurrence fields --- is_recurring = Column(Boolean, default=False, nullable=False) recurrence_pattern_id = Column(Integer, ForeignKey("recurrence_patterns.id"), nullable=True) recurrence_pattern = relationship("RecurrencePattern", back_populates="expenses", uselist=False) # One-to-one @@ -298,7 +248,6 @@ class Expense(Base): last_occurrence = Column(DateTime(timezone=True), nullable=True) __table_args__ = ( - # Ensure at least one context is provided CheckConstraint('(item_id IS NOT NULL) OR (list_id IS NOT NULL) OR (group_id 
IS NOT NULL)', name='chk_expense_context'), ) @@ -320,14 +269,12 @@ class ExpenseSplit(Base): created_at = Column(DateTime(timezone=True), server_default=func.now(), nullable=False) updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now(), nullable=False) - # Relationships expense = relationship("Expense", back_populates="splits") user = relationship("User", foreign_keys=[user_id], back_populates="expense_splits") settlement_activities = relationship("SettlementActivity", back_populates="split", cascade="all, delete-orphan") - # New fields for tracking payment status status = Column(SAEnum(ExpenseSplitStatusEnum, name="expensesplitstatusenum", create_type=True), nullable=False, server_default=ExpenseSplitStatusEnum.unpaid.value, default=ExpenseSplitStatusEnum.unpaid) - paid_at = Column(DateTime(timezone=True), nullable=True) # Timestamp when the split was fully paid + paid_at = Column(DateTime(timezone=True), nullable=True) class Settlement(Base): __tablename__ = "settlements" @@ -345,33 +292,28 @@ class Settlement(Base): updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now(), nullable=False) version = Column(Integer, nullable=False, default=1, server_default='1') - # Relationships group = relationship("Group", foreign_keys=[group_id], back_populates="settlements") payer = relationship("User", foreign_keys=[paid_by_user_id], back_populates="settlements_made") payee = relationship("User", foreign_keys=[paid_to_user_id], back_populates="settlements_received") created_by_user = relationship("User", foreign_keys=[created_by_user_id], back_populates="settlements_created") __table_args__ = ( - # Ensure payer and payee are different users CheckConstraint('paid_by_user_id != paid_to_user_id', name='chk_settlement_different_users'), ) -# Potential future: PaymentMethod model, etc. 
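The version columns declared on List, Item, Expense and Settlement above back the optimistic-locking scheme the update schemas expose (clients must send back the version they read). A minimal sketch of the usual compare-and-bump UPDATE, assuming that pattern; the helper below is illustrative and not part of this change.

from sqlalchemy import update
from sqlalchemy.ext.asyncio import AsyncSession

from app.models import List as ListModel

async def rename_list(db: AsyncSession, list_id: int, new_name: str, expected_version: int) -> bool:
    # Only applies the change if nobody bumped the version since the client read it.
    result = await db.execute(
        update(ListModel)
        .where(ListModel.id == list_id, ListModel.version == expected_version)
        .values(name=new_name, version=ListModel.version + 1)
    )
    await db.flush()
    # rowcount == 0 means the row was modified concurrently (stale version).
    return result.rowcount == 1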
- class SettlementActivity(Base): __tablename__ = "settlement_activities" id = Column(Integer, primary_key=True, index=True) expense_split_id = Column(Integer, ForeignKey("expense_splits.id"), nullable=False, index=True) - paid_by_user_id = Column(Integer, ForeignKey("users.id"), nullable=False, index=True) # User who made this part of the payment + paid_by_user_id = Column(Integer, ForeignKey("users.id"), nullable=False, index=True) paid_at = Column(DateTime(timezone=True), server_default=func.now(), nullable=False) amount_paid = Column(Numeric(10, 2), nullable=False) - created_by_user_id = Column(Integer, ForeignKey("users.id"), nullable=False, index=True) # User who recorded this activity + created_by_user_id = Column(Integer, ForeignKey("users.id"), nullable=False, index=True) created_at = Column(DateTime(timezone=True), server_default=func.now(), nullable=False) updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now(), nullable=False) - # --- Relationships --- split = relationship("ExpenseSplit", back_populates="settlement_activities") payer = relationship("User", foreign_keys=[paid_by_user_id], backref="made_settlement_activities") creator = relationship("User", foreign_keys=[created_by_user_id], backref="created_settlement_activities") @@ -395,15 +337,14 @@ class Chore(Base): created_by_id = Column(Integer, ForeignKey("users.id"), nullable=False, index=True) frequency = Column(SAEnum(ChoreFrequencyEnum, name="chorefrequencyenum", create_type=True), nullable=False) - custom_interval_days = Column(Integer, nullable=True) # Only if frequency is 'custom' + custom_interval_days = Column(Integer, nullable=True) - next_due_date = Column(Date, nullable=False) # Changed to Date + next_due_date = Column(Date, nullable=False) last_completed_at = Column(DateTime(timezone=True), nullable=True) created_at = Column(DateTime(timezone=True), server_default=func.now(), nullable=False) updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now(), nullable=False) - # --- Relationships --- group = relationship("Group", back_populates="chores") creator = relationship("User", back_populates="created_chores") assignments = relationship("ChoreAssignment", back_populates="chore", cascade="all, delete-orphan") @@ -418,14 +359,13 @@ class ChoreAssignment(Base): chore_id = Column(Integer, ForeignKey("chores.id", ondelete="CASCADE"), nullable=False, index=True) assigned_to_user_id = Column(Integer, ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True) - due_date = Column(Date, nullable=False) # Specific due date for this instance, changed to Date + due_date = Column(Date, nullable=False) is_complete = Column(Boolean, default=False, nullable=False) completed_at = Column(DateTime(timezone=True), nullable=True) created_at = Column(DateTime(timezone=True), server_default=func.now(), nullable=False) updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now(), nullable=False) - # --- Relationships --- chore = relationship("Chore", back_populates="assignments") assigned_user = relationship("User", back_populates="assigned_chores") history = relationship("ChoreAssignmentHistory", back_populates="assignment", cascade="all, delete-orphan") @@ -437,21 +377,14 @@ class RecurrencePattern(Base): id = Column(Integer, primary_key=True, index=True) type = Column(SAEnum(RecurrenceTypeEnum, name="recurrencetypeenum", create_type=True), nullable=False) - interval = Column(Integer, default=1, nullable=False) # e.g., 
every 1 day, every 2 weeks - days_of_week = Column(String, nullable=True) # For weekly recurrences, e.g., "MON,TUE,FRI" - # day_of_month = Column(Integer, nullable=True) # For monthly on a specific day - # week_of_month = Column(Integer, nullable=True) # For monthly on a specific week (e.g., 2nd week) - # month_of_year = Column(Integer, nullable=True) # For yearly recurrences + interval = Column(Integer, default=1, nullable=False) + days_of_week = Column(String, nullable=True) end_date = Column(DateTime(timezone=True), nullable=True) max_occurrences = Column(Integer, nullable=True) created_at = Column(DateTime(timezone=True), server_default=func.now(), nullable=False) updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now(), nullable=False) - # Relationship back to Expenses that use this pattern (could be one-to-many if patterns are shared) - # However, the current CRUD implies one RecurrencePattern per Expense if recurring. - # If a pattern can be shared, this would be a one-to-many (RecurrencePattern to many Expenses). - # For now, assuming one-to-one as implied by current Expense.recurrence_pattern relationship setup. expenses = relationship("Expense", back_populates="recurrence_pattern") @@ -464,13 +397,12 @@ class ChoreHistory(Base): id = Column(Integer, primary_key=True, index=True) chore_id = Column(Integer, ForeignKey("chores.id", ondelete="CASCADE"), nullable=True, index=True) - group_id = Column(Integer, ForeignKey("groups.id", ondelete="CASCADE"), nullable=True, index=True) # For group-level events + group_id = Column(Integer, ForeignKey("groups.id", ondelete="CASCADE"), nullable=True, index=True) event_type = Column(SAEnum(ChoreHistoryEventTypeEnum, name="chorehistoryeventtypeenum", create_type=True), nullable=False) - event_data = Column(JSONB, nullable=True) # e.g., {'field': 'name', 'old': 'Old', 'new': 'New'} - changed_by_user_id = Column(Integer, ForeignKey("users.id"), nullable=True) # Nullable if system-generated + event_data = Column(JSONB, nullable=True) + changed_by_user_id = Column(Integer, ForeignKey("users.id"), nullable=True) timestamp = Column(DateTime(timezone=True), server_default=func.now(), nullable=False) - # --- Relationships --- chore = relationship("Chore", back_populates="history") group = relationship("Group", back_populates="chore_history") changed_by_user = relationship("User", back_populates="chore_history_entries") @@ -480,11 +412,10 @@ class ChoreAssignmentHistory(Base): id = Column(Integer, primary_key=True, index=True) assignment_id = Column(Integer, ForeignKey("chore_assignments.id", ondelete="CASCADE"), nullable=False, index=True) - event_type = Column(SAEnum(ChoreHistoryEventTypeEnum, name="chorehistoryeventtypeenum", create_type=True), nullable=False) # Reusing enum + event_type = Column(SAEnum(ChoreHistoryEventTypeEnum, name="chorehistoryeventtypeenum", create_type=True), nullable=False) event_data = Column(JSONB, nullable=True) changed_by_user_id = Column(Integer, ForeignKey("users.id"), nullable=True) timestamp = Column(DateTime(timezone=True), server_default=func.now(), nullable=False) - # --- Relationships --- assignment = relationship("ChoreAssignment", back_populates="history") changed_by_user = relationship("User", back_populates="assignment_history_entries") diff --git a/be/app/models/expense.py b/be/app/models/expense.py deleted file mode 100644 index e69de29..0000000 diff --git a/be/app/schemas/auth.py b/be/app/schemas/auth.py index c0c4fcb..928787e 100644 --- a/be/app/schemas/auth.py +++ 
b/be/app/schemas/auth.py @@ -1,13 +1,7 @@ -# app/schemas/auth.py -from pydantic import BaseModel, EmailStr +from pydantic import BaseModel from app.config import settings class Token(BaseModel): access_token: str - refresh_token: str # Added refresh token - token_type: str = settings.TOKEN_TYPE # Use configured token type - -# Optional: If you preferred not to use OAuth2PasswordRequestForm -# class UserLogin(BaseModel): -# email: EmailStr -# password: str \ No newline at end of file + refresh_token: str + token_type: str = settings.TOKEN_TYPE \ No newline at end of file diff --git a/be/app/schemas/chore.py b/be/app/schemas/chore.py index 3605164..5cf6f8f 100644 --- a/be/app/schemas/chore.py +++ b/be/app/schemas/chore.py @@ -1,18 +1,12 @@ from datetime import date, datetime from typing import Optional, List, Any from pydantic import BaseModel, ConfigDict, field_validator +from ..models import ChoreFrequencyEnum, ChoreTypeEnum, ChoreHistoryEventTypeEnum +from .user import UserPublic -# Assuming ChoreFrequencyEnum is imported from models -# Adjust the import path if necessary based on your project structure. -# e.g., from app.models import ChoreFrequencyEnum -from ..models import ChoreFrequencyEnum, ChoreTypeEnum, User as UserModel, ChoreHistoryEventTypeEnum # For UserPublic relation -from .user import UserPublic # For embedding user information - -# Forward declaration for circular dependencies class ChoreAssignmentPublic(BaseModel): pass -# History Schemas class ChoreHistoryPublic(BaseModel): id: int event_type: ChoreHistoryEventTypeEnum @@ -32,7 +26,6 @@ class ChoreAssignmentHistoryPublic(BaseModel): model_config = ConfigDict(from_attributes=True) -# Chore Schemas class ChoreBase(BaseModel): name: str description: Optional[str] = None diff --git a/be/app/schemas/cost.py b/be/app/schemas/cost.py index 49dc0a8..e9a7868 100644 --- a/be/app/schemas/cost.py +++ b/be/app/schemas/cost.py @@ -4,10 +4,10 @@ from decimal import Decimal class UserCostShare(BaseModel): user_id: int - user_identifier: str # Name or email - items_added_value: Decimal = Decimal("0.00") # Total value of items this user added - amount_due: Decimal # The user's share of the total cost (for equal split, this is total_cost / num_users) - balance: Decimal # items_added_value - amount_due + user_identifier: str + items_added_value: Decimal = Decimal("0.00") + amount_due: Decimal + balance: Decimal model_config = ConfigDict(from_attributes=True) @@ -23,19 +23,19 @@ class ListCostSummary(BaseModel): class UserBalanceDetail(BaseModel): user_id: int - user_identifier: str # Name or email + user_identifier: str total_paid_for_expenses: Decimal = Decimal("0.00") total_share_of_expenses: Decimal = Decimal("0.00") total_settlements_paid: Decimal = Decimal("0.00") total_settlements_received: Decimal = Decimal("0.00") - net_balance: Decimal = Decimal("0.00") # (paid_for_expenses + settlements_received) - (share_of_expenses + settlements_paid) + net_balance: Decimal = Decimal("0.00") model_config = ConfigDict(from_attributes=True) class SuggestedSettlement(BaseModel): from_user_id: int - from_user_identifier: str # Name or email of payer + from_user_identifier: str to_user_id: int - to_user_identifier: str # Name or email of payee + to_user_identifier: str amount: Decimal model_config = ConfigDict(from_attributes=True) @@ -45,11 +45,5 @@ class GroupBalanceSummary(BaseModel): overall_total_expenses: Decimal = Decimal("0.00") overall_total_settlements: Decimal = Decimal("0.00") user_balances: List[UserBalanceDetail] - # Optional: Could add 
a list of suggested settlements to zero out balances suggested_settlements: Optional[List[SuggestedSettlement]] = None - model_config = ConfigDict(from_attributes=True) - -# class SuggestedSettlement(BaseModel): -# from_user_id: int -# to_user_id: int -# amount: Decimal \ No newline at end of file + model_config = ConfigDict(from_attributes=True) \ No newline at end of file diff --git a/be/app/schemas/expense.py b/be/app/schemas/expense.py index 6c24b43..7b26f47 100644 --- a/be/app/schemas/expense.py +++ b/be/app/schemas/expense.py @@ -1,19 +1,11 @@ -# app/schemas/expense.py from pydantic import BaseModel, ConfigDict, validator, Field -from typing import List, Optional, Dict, Any +from typing import List, Optional from decimal import Decimal from datetime import datetime +from app.models import SplitTypeEnum, ExpenseSplitStatusEnum, ExpenseOverallStatusEnum +from app.schemas.user import UserPublic +from app.schemas.settlement_activity import SettlementActivityPublic -# Assuming SplitTypeEnum is accessible here, e.g., from app.models or app.core.enums -# For now, let's redefine it or import it if models.py is parsable by Pydantic directly -# If it's from app.models, you might need to make app.models.SplitTypeEnum Pydantic-compatible or map it. -# For simplicity during schema definition, I'll redefine a string enum here. -# In a real setup, ensure this aligns with the SQLAlchemy enum in models.py. -from app.models import SplitTypeEnum, ExpenseSplitStatusEnum, ExpenseOverallStatusEnum # Try importing directly -from app.schemas.user import UserPublic # For user details in responses -from app.schemas.settlement_activity import SettlementActivityPublic # For settlement activities - -# --- ExpenseSplit Schemas --- class ExpenseSplitBase(BaseModel): user_id: int owed_amount: Decimal @@ -21,20 +13,19 @@ class ExpenseSplitBase(BaseModel): share_units: Optional[int] = None class ExpenseSplitCreate(ExpenseSplitBase): - pass # All fields from base are needed for creation + pass class ExpenseSplitPublic(ExpenseSplitBase): id: int expense_id: int - user: Optional[UserPublic] = None # If we want to nest user details + user: Optional[UserPublic] = None created_at: datetime updated_at: datetime - status: ExpenseSplitStatusEnum # New field - paid_at: Optional[datetime] = None # New field - settlement_activities: List[SettlementActivityPublic] = [] # New field + status: ExpenseSplitStatusEnum + paid_at: Optional[datetime] = None + settlement_activities: List[SettlementActivityPublic] = [] model_config = ConfigDict(from_attributes=True) -# --- Expense Schemas --- class RecurrencePatternBase(BaseModel): type: str = Field(..., description="Type of recurrence: daily, weekly, monthly, yearly") interval: int = Field(..., description="Interval of recurrence (e.g., every X days/weeks/months/years)") @@ -63,16 +54,13 @@ class ExpenseBase(BaseModel): expense_date: Optional[datetime] = None split_type: SplitTypeEnum list_id: Optional[int] = None - group_id: Optional[int] = None # Should be present if list_id is not, and vice-versa + group_id: Optional[int] = None item_id: Optional[int] = None paid_by_user_id: int is_recurring: bool = Field(False, description="Whether this is a recurring expense") recurrence_pattern: Optional[RecurrencePatternCreate] = Field(None, description="Recurrence pattern for recurring expenses") class ExpenseCreate(ExpenseBase): - # For EQUAL split, splits are generated. For others, they might be provided. 
- # This logic will be in the CRUD: if split_type is EXACT_AMOUNTS, PERCENTAGE, SHARES, - # then 'splits_in' should be provided. splits_in: Optional[List[ExpenseSplitCreate]] = None @validator('total_amount') @@ -81,8 +69,6 @@ class ExpenseCreate(ExpenseBase): raise ValueError('Total amount must be positive') return v - # Basic validation: if list_id is None, group_id must be provided. - # More complex cross-field validation might be needed. @validator('group_id', always=True) def check_list_or_group_id(cls, v, values): if values.get('list_id') is None and v is None: @@ -105,10 +91,8 @@ class ExpenseUpdate(BaseModel): split_type: Optional[SplitTypeEnum] = None list_id: Optional[int] = None group_id: Optional[int] = None - item_id: Optional[int] = None - # paid_by_user_id is usually not updatable directly to maintain integrity. - # Updating splits would be a more complex operation, potentially a separate endpoint or careful logic. - version: int # For optimistic locking + item_id: Optional[int] = None + version: int is_recurring: Optional[bool] = None recurrence_pattern: Optional[RecurrencePatternUpdate] = None next_occurrence: Optional[datetime] = None @@ -120,11 +104,8 @@ class ExpensePublic(ExpenseBase): version: int created_by_user_id: int splits: List[ExpenseSplitPublic] = [] - paid_by_user: Optional[UserPublic] = None # If nesting user details - overall_settlement_status: ExpenseOverallStatusEnum # New field - # list: Optional[ListPublic] # If nesting list details - # group: Optional[GroupPublic] # If nesting group details - # item: Optional[ItemPublic] # If nesting item details + paid_by_user: Optional[UserPublic] = None + overall_settlement_status: ExpenseOverallStatusEnum is_recurring: bool next_occurrence: Optional[datetime] last_occurrence: Optional[datetime] @@ -133,7 +114,6 @@ class ExpensePublic(ExpenseBase): generated_expenses: List['ExpensePublic'] = [] model_config = ConfigDict(from_attributes=True) -# --- Settlement Schemas --- class SettlementBase(BaseModel): group_id: int paid_by_user_id: int @@ -159,8 +139,7 @@ class SettlementUpdate(BaseModel): amount: Optional[Decimal] = None settlement_date: Optional[datetime] = None description: Optional[str] = None - # group_id, paid_by_user_id, paid_to_user_id are typically not updatable. 
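The check_list_or_group_id validator above enforces that an expense is created with at least one context (a list or a group). A standalone illustration of that cross-field pattern; ExpenseContext is a toy model, not the real schema.

from typing import Optional
from pydantic import BaseModel, validator

class ExpenseContext(BaseModel):
    list_id: Optional[int] = None
    group_id: Optional[int] = None

    @validator('group_id', always=True)
    def check_list_or_group_id(cls, v, values):
        # Runs even when group_id is omitted, so the rule always applies.
        if values.get('list_id') is None and v is None:
            raise ValueError('Either list_id or group_id must be provided')
        return v

# ExpenseContext()            -> ValidationError
# ExpenseContext(list_id=1)   -> ok
# ExpenseContext(group_id=2)  -> ok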
- version: int # For optimistic locking + version: int class SettlementPublic(SettlementBase): id: int @@ -168,13 +147,4 @@ class SettlementPublic(SettlementBase): updated_at: datetime version: int created_by_user_id: int - # payer: Optional[UserPublic] # If we want to include payer details - # payee: Optional[UserPublic] # If we want to include payee details - # group: Optional[GroupPublic] # If we want to include group details - model_config = ConfigDict(from_attributes=True) - -# Placeholder for nested schemas (e.g., UserPublic) if needed -# from app.schemas.user import UserPublic -# from app.schemas.list import ListPublic -# from app.schemas.group import GroupPublic -# from app.schemas.item import ItemPublic \ No newline at end of file + model_config = ConfigDict(from_attributes=True) \ No newline at end of file diff --git a/be/app/schemas/group.py b/be/app/schemas/group.py index ea43806..9645447 100644 --- a/be/app/schemas/group.py +++ b/be/app/schemas/group.py @@ -1,22 +1,17 @@ -# app/schemas/group.py from pydantic import BaseModel, ConfigDict, computed_field from datetime import datetime, date -from typing import Optional, List +from typing import Optional, List +from .user import UserPublic +from .chore import ChoreHistoryPublic -from .user import UserPublic # Import UserPublic to represent members -from .chore import ChoreHistoryPublic # Import for history - -# Properties to receive via API on creation class GroupCreate(BaseModel): name: str -# New schema for generating a schedule class GroupScheduleGenerateRequest(BaseModel): start_date: date end_date: date - member_ids: Optional[List[int]] = None # Optional: if not provided, use all members + member_ids: Optional[List[int]] = None -# Properties to return to client class GroupPublic(BaseModel): id: int name: str @@ -34,7 +29,6 @@ class GroupPublic(BaseModel): model_config = ConfigDict(from_attributes=True) -# Properties for UserGroup association class UserGroupPublic(BaseModel): id: int user_id: int @@ -45,9 +39,4 @@ class UserGroupPublic(BaseModel): model_config = ConfigDict(from_attributes=True) -# Properties stored in DB (if needed, often GroupPublic is sufficient) -# class GroupInDB(GroupPublic): -# pass - -# We need to rebuild GroupPublic to resolve the forward reference to UserGroupPublic GroupPublic.model_rebuild() \ No newline at end of file diff --git a/be/app/schemas/health.py b/be/app/schemas/health.py index bbb00b7..511b711 100644 --- a/be/app/schemas/health.py +++ b/be/app/schemas/health.py @@ -1,4 +1,4 @@ -# app/schemas/health.py + from pydantic import BaseModel from app.config import settings @@ -6,5 +6,5 @@ class HealthStatus(BaseModel): """ Response model for the health check endpoint. 
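GroupPublic above forward-references UserGroupPublic (defined later in the same module), which is why GroupPublic.model_rebuild() is called at the bottom of the file. A minimal standalone illustration of that pattern; Parent and Child are toy models, not the real schemas.

from typing import List, Optional
from pydantic import BaseModel

class Parent(BaseModel):
    id: int
    children: List['Child'] = []   # forward reference, resolved later

class Child(BaseModel):
    id: int
    parent_id: Optional[int] = None

Parent.model_rebuild()  # resolves 'Child' now that it is defined
Parent(id=1, children=[Child(id=2, parent_id=1)])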
""" - status: str = settings.HEALTH_STATUS_OK # Use configured default value + status: str = settings.HEALTH_STATUS_OK database: str \ No newline at end of file diff --git a/be/app/schemas/invite.py b/be/app/schemas/invite.py index 63dcb79..25f6436 100644 --- a/be/app/schemas/invite.py +++ b/be/app/schemas/invite.py @@ -1,12 +1,9 @@ -# app/schemas/invite.py from pydantic import BaseModel from datetime import datetime -# Properties to receive when accepting an invite class InviteAccept(BaseModel): code: str -# Properties to return when an invite is created class InviteCodePublic(BaseModel): code: str expires_at: datetime diff --git a/be/app/schemas/item.py b/be/app/schemas/item.py index 33d3348..ef3fe67 100644 --- a/be/app/schemas/item.py +++ b/be/app/schemas/item.py @@ -1,10 +1,8 @@ -# app/schemas/item.py from pydantic import BaseModel, ConfigDict from datetime import datetime from typing import Optional from decimal import Decimal -# Properties to return to client class ItemPublic(BaseModel): id: int list_id: int @@ -19,19 +17,14 @@ class ItemPublic(BaseModel): version: int model_config = ConfigDict(from_attributes=True) -# Properties to receive via API on creation class ItemCreate(BaseModel): name: str quantity: Optional[str] = None - # list_id will be from path param - # added_by_id will be from current_user -# Properties to receive via API on update class ItemUpdate(BaseModel): name: Optional[str] = None quantity: Optional[str] = None is_complete: Optional[bool] = None - price: Optional[Decimal] = None # Price added here for update - position: Optional[int] = None # For reordering - version: int - # completed_by_id will be set internally if is_complete is true \ No newline at end of file + price: Optional[Decimal] = None + position: Optional[int] = None + version: int \ No newline at end of file diff --git a/be/app/schemas/list.py b/be/app/schemas/list.py index b21e506..fe8da4c 100644 --- a/be/app/schemas/list.py +++ b/be/app/schemas/list.py @@ -1,25 +1,20 @@ -# app/schemas/list.py from pydantic import BaseModel, ConfigDict from datetime import datetime from typing import Optional, List -from .item import ItemPublic # Import item schema for nesting +from .item import ItemPublic -# Properties to receive via API on creation class ListCreate(BaseModel): name: str description: Optional[str] = None - group_id: Optional[int] = None # Optional for sharing + group_id: Optional[int] = None -# Properties to receive via API on update class ListUpdate(BaseModel): name: Optional[str] = None description: Optional[str] = None is_complete: Optional[bool] = None - version: int # Client must provide the version for updates - # Potentially add group_id update later if needed + version: int -# Base properties returned by API (common fields) class ListBase(BaseModel): id: int name: str @@ -29,17 +24,15 @@ class ListBase(BaseModel): is_complete: bool created_at: datetime updated_at: datetime - version: int # Include version in responses + version: int model_config = ConfigDict(from_attributes=True) -# Properties returned when listing lists (no items) class ListPublic(ListBase): - pass # Inherits all from ListBase + pass -# Properties returned for a single list detail (includes items) class ListDetail(ListBase): - items: List[ItemPublic] = [] # Include list of items + items: List[ItemPublic] = [] class ListStatus(BaseModel): updated_at: datetime diff --git a/be/app/schemas/message.py b/be/app/schemas/message.py index 04b9e0e..10606af 100644 --- a/be/app/schemas/message.py +++ b/be/app/schemas/message.py @@ 
-1,4 +1,3 @@ -# app/schemas/message.py from pydantic import BaseModel class Message(BaseModel): diff --git a/be/app/schemas/ocr.py b/be/app/schemas/ocr.py index 0b4bba9..5baad0b 100644 --- a/be/app/schemas/ocr.py +++ b/be/app/schemas/ocr.py @@ -1,6 +1,5 @@ -# app/schemas/ocr.py from pydantic import BaseModel from typing import List class OcrExtractResponse(BaseModel): - extracted_items: List[str] # A list of potential item names \ No newline at end of file + extracted_items: List[str] \ No newline at end of file diff --git a/be/app/schemas/settlement_activity.py b/be/app/schemas/settlement_activity.py index 2c2b021..704e832 100644 --- a/be/app/schemas/settlement_activity.py +++ b/be/app/schemas/settlement_activity.py @@ -3,7 +3,7 @@ from typing import Optional, List from decimal import Decimal from datetime import datetime -from app.schemas.user import UserPublic # Assuming UserPublic is defined here +from app.schemas.user import UserPublic class SettlementActivityBase(BaseModel): expense_split_id: int @@ -21,23 +21,13 @@ class SettlementActivityCreate(SettlementActivityBase): class SettlementActivityPublic(SettlementActivityBase): id: int - created_by_user_id: int # User who recorded this activity + created_by_user_id: int created_at: datetime updated_at: datetime - payer: Optional[UserPublic] = None # User who made this part of the payment - creator: Optional[UserPublic] = None # User who recorded this activity + payer: Optional[UserPublic] = None + creator: Optional[UserPublic] = None model_config = ConfigDict(from_attributes=True) -# Schema for updating a settlement activity (if needed in the future) -# class SettlementActivityUpdate(BaseModel): -# amount_paid: Optional[Decimal] = None -# paid_at: Optional[datetime] = None -# @field_validator('amount_paid') -# @classmethod -# def amount_must_be_positive_if_provided(cls, v: Optional[Decimal]) -> Optional[Decimal]: -# if v is not None and v <= Decimal("0"): -# raise ValueError("Amount paid must be a positive value.") -# return v diff --git a/be/app/schemas/user.py b/be/app/schemas/user.py index ed5332c..50d6133 100644 --- a/be/app/schemas/user.py +++ b/be/app/schemas/user.py @@ -1,14 +1,11 @@ -# app/schemas/user.py from pydantic import BaseModel, EmailStr, ConfigDict from datetime import datetime from typing import Optional -# Shared properties class UserBase(BaseModel): email: EmailStr name: Optional[str] = None -# Properties to receive via API on creation class UserCreate(UserBase): password: str @@ -22,26 +19,22 @@ class UserCreate(UserBase): "is_verified": False } -# Properties to receive via API on update class UserUpdate(UserBase): password: Optional[str] = None is_active: Optional[bool] = None is_superuser: Optional[bool] = None is_verified: Optional[bool] = None -# Properties stored in DB class UserInDBBase(UserBase): id: int password_hash: str created_at: datetime - model_config = ConfigDict(from_attributes=True) # Use orm_mode in Pydantic v1 + model_config = ConfigDict(from_attributes=True) -# Additional properties to return via API (excluding password) class UserPublic(UserBase): id: int created_at: datetime model_config = ConfigDict(from_attributes=True) -# Full user model including hashed password (for internal use/reading from DB) class User(UserInDBBase): pass \ No newline at end of file diff --git a/fe/src/config/api-config.ts b/fe/src/config/api-config.ts index 2fa49fe..e2db33e 100644 --- a/fe/src/config/api-config.ts +++ b/fe/src/config/api-config.ts @@ -33,6 +33,7 @@ export const API_ENDPOINTS = { BASE: '/lists', 
BY_ID: (id: string) => `/lists/${id}`, STATUS: (id: string) => `/lists/${id}/status`, + STATUSES: '/lists/statuses', ITEMS: (listId: string) => `/lists/${listId}/items`, ITEM: (listId: string, itemId: string) => `/lists/${listId}/items/${itemId}`, EXPENSES: (listId: string) => `/lists/${listId}/expenses`, diff --git a/fe/src/layouts/AuthLayout.vue b/fe/src/layouts/AuthLayout.vue index ad608b9..3cd2cb7 100644 --- a/fe/src/layouts/AuthLayout.vue +++ b/fe/src/layouts/AuthLayout.vue @@ -7,7 +7,6 @@ \ No newline at end of file diff --git a/fe/src/layouts/MainLayout.vue b/fe/src/layouts/MainLayout.vue index a6ad428..f0ea2be 100644 --- a/fe/src/layouts/MainLayout.vue +++ b/fe/src/layouts/MainLayout.vue @@ -4,9 +4,7 @@
mitlist [MainLayout.vue template markup for this hunk not recoverable]
@@ -81,32 +74,24 @@ [MainLayout.vue template markup for this hunk not recoverable]
- @@ -124,7 +109,6 @@ import CreateListModal from '@/components/CreateListModal.vue'; import CreateGroupModal from '@/components/CreateGroupModal.vue'; import { onClickOutside } from '@vueuse/core'; -// Store and Router setup const router = useRouter(); const route = useRoute(); const authStore = useAuthStore(); @@ -132,30 +116,24 @@ const notificationStore = useNotificationStore(); const groupStore = useGroupStore(); const { t, locale } = useI18n(); -// --- Dropdown Logic (Re-integrated from composable) --- - -// 1. Add Menu Dropdown const addMenuOpen = ref(false); const addMenuDropdown = ref(null); const addMenuTrigger = ref(null); const toggleAddMenu = () => { addMenuOpen.value = !addMenuOpen.value; }; onClickOutside(addMenuDropdown, () => { addMenuOpen.value = false; }, { ignore: [addMenuTrigger] }); -// 2. Language Menu Dropdown const languageMenuOpen = ref(false); const languageMenuDropdown = ref(null); const languageMenuTrigger = ref(null); const toggleLanguageMenu = () => { languageMenuOpen.value = !languageMenuOpen.value; }; onClickOutside(languageMenuDropdown, () => { languageMenuOpen.value = false; }, { ignore: [languageMenuTrigger] }); -// 3. User Menu Dropdown const userMenuOpen = ref(false); const userMenuDropdown = ref(null); const userMenuTrigger = ref(null); const toggleUserMenu = () => { userMenuOpen.value = !userMenuOpen.value; }; onClickOutside(userMenuDropdown, () => { userMenuOpen.value = false; }, { ignore: [userMenuTrigger] }); -// --- Language Selector Logic --- const availableLanguages = computed(() => ({ en: t('languageSelector.languages.en'), de: t('languageSelector.languages.de'), @@ -168,24 +146,23 @@ const currentLanguageCode = computed(() => locale.value); const changeLanguage = (languageCode: string) => { locale.value = languageCode; localStorage.setItem('language', languageCode); - languageMenuOpen.value = false; // Close menu on selection + languageMenuOpen.value = false; notificationStore.addNotification({ type: 'success', message: `Language changed to ${availableLanguages.value[languageCode as keyof typeof availableLanguages.value]}`, }); }; -// --- Modal Handling --- const showCreateListModal = ref(false); const showCreateGroupModal = ref(false); const handleAddList = () => { - addMenuOpen.value = false; // Close menu + addMenuOpen.value = false; showCreateListModal.value = true; }; const handleAddGroup = () => { - addMenuOpen.value = false; // Close menu + addMenuOpen.value = false; showCreateGroupModal.value = true; }; @@ -197,13 +174,12 @@ const handleListCreated = (newList: any) => { const handleGroupCreated = (newGroup: any) => { notificationStore.addNotification({ message: `Group '${newGroup.name}' created successfully`, type: 'success' }); showCreateGroupModal.value = false; - groupStore.fetchGroups(); // Refresh groups after creation + groupStore.fetchGroups(); }; -// --- User and Navigation Logic --- const handleLogout = async () => { try { - userMenuOpen.value = false; // Close menu + userMenuOpen.value = false; authStore.logout(); notificationStore.addNotification({ type: 'success', message: 'Logged out successfully' }); await router.push('/auth/login'); @@ -224,9 +200,7 @@ const navigateToGroups = () => { } }; -// --- App Initialization --- onMounted(async () => { - // Fetch essential data for authenticated users if (authStore.isAuthenticated) { try { await authStore.fetchCurrentUser(); @@ -236,7 +210,6 @@ onMounted(async () => { } } - // Load saved language preference const savedLanguage = localStorage.getItem('language'); if 
(savedLanguage && Object.keys(availableLanguages.value).includes(savedLanguage)) { locale.value = savedLanguage; @@ -245,14 +218,13 @@ onMounted(async () => { \ No newline at end of file diff --git a/fe/src/pages/GroupDetailPage.vue b/fe/src/pages/GroupDetailPage.vue index 2eb1eae..86ef14d 100644 --- a/fe/src/pages/GroupDetailPage.vue +++ b/fe/src/pages/GroupDetailPage.vue @@ -42,7 +42,6 @@ + -
{{ t('groupDetailPage.expenses.title') }} @@ -181,7 +178,6 @@ [GroupDetailPage.vue template markup for this hunk not recoverable]
{{ t('groupDetailPage.activityLog.title') }} @@ -199,7 +195,6 @@ [GroupDetailPage.vue template markup for this hunk not recoverable]