from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy import union_all, or_
from typing import List, Optional

from app.models import FinancialAuditLog, Base, User, Group, Expense, Settlement
from app.schemas.audit import FinancialAuditLogCreate


async def create_financial_audit_log(
    db: AsyncSession,
    *,
    user_id: int | None,
    action_type: str,
    entity: Base,
    details: dict | None = None,
) -> FinancialAuditLog:
    """Persist an audit-log entry describing an action on ``entity``.

    The entry stores the entity's class name (``entity.__class__.__name__``,
    e.g. ``"Expense"``) as ``entity_type`` and ``entity.id`` as ``entity_id``.
    ``get_financial_audit_logs_for_group`` matches on that same class-name
    convention.

    Args:
        db: Async session used to write the entry. NOTE: this function calls
            ``db.commit()``, which commits the session's entire current
            transaction, including any other pending changes the caller added.
        user_id: ID of the user who performed the action, or ``None``.
        action_type: Label for the action, stored as-is.
        entity: Model instance the action applies to. Presumably it must
            already be flushed so that ``entity.id`` is populated (TODO confirm
            with callers).
        details: Optional extra data stored with the entry.

    Returns:
        The newly created ``FinancialAuditLog``, refreshed from the database.
    """
    log_entry_data = FinancialAuditLogCreate(
        user_id=user_id,
        action_type=action_type,
        entity_type=entity.__class__.__name__,
        entity_id=entity.id,
        details=details,
    )
    log_entry = FinancialAuditLog(**log_entry_data.dict())
    db.add(log_entry)
    await db.commit()
    # Reload server-generated columns (e.g. primary key / defaults) onto the object.
    await db.refresh(log_entry)
    return log_entry


async def get_financial_audit_logs_for_group(
    db: AsyncSession, *, group_id: int, skip: int = 0, limit: int = 100
) -> List[FinancialAuditLog]:
    """
    Get financial audit logs for all entities that belong to a specific group.

    This includes Expenses and Settlements that are linked to the group. The
    group's entity IDs are matched with ``IN (SELECT ...)`` subqueries, so the
    whole lookup is a single query. This avoids fetching every ID first and
    binding one parameter per ID, which costs extra round trips and can exceed
    bind-parameter limits for large groups. A group with no expenses or
    settlements simply yields an empty list.

    Args:
        db: Async session used for the query.
        group_id: ID of the group whose entities' logs are wanted.
        skip: Number of rows to skip (offset pagination).
        limit: Maximum number of rows to return.

    Returns:
        Matching audit-log rows, newest ``timestamp`` first.
    """
    # Subqueries, not materialized ID lists: the database resolves them inline.
    expense_ids = select(Expense.id).where(Expense.group_id == group_id)
    settlement_ids = select(Settlement.id).where(Settlement.group_id == group_id)

    query = (
        select(FinancialAuditLog)
        .where(
            or_(
                # entity_type holds the model class name, as written by
                # create_financial_audit_log via entity.__class__.__name__.
                (FinancialAuditLog.entity_type == Expense.__name__)
                & FinancialAuditLog.entity_id.in_(expense_ids),
                (FinancialAuditLog.entity_type == Settlement.__name__)
                & FinancialAuditLog.entity_id.in_(settlement_ids),
            )
        )
        # NOTE(review): rows sharing a timestamp have no defined relative order,
        # so offset pagination may repeat/skip them across pages; consider a
        # unique secondary sort key.
        .order_by(FinancialAuditLog.timestamp.desc())
        .offset(skip)
        .limit(limit)
    )
    result = await db.execute(query)
    return result.scalars().all()


async def get_financial_audit_logs_for_user(
    db: AsyncSession, *, user_id: int, skip: int = 0, limit: int = 100
) -> List[FinancialAuditLog]:
    """Return audit-log entries whose ``user_id`` equals ``user_id``.

    Args:
        db: Async session used for the query.
        user_id: ID of the user whose logged actions are wanted.
        skip: Number of rows to skip (offset pagination).
        limit: Maximum number of rows to return.

    Returns:
        Matching audit-log rows, newest ``timestamp`` first.
    """
    result = await db.execute(
        select(FinancialAuditLog)
        .where(FinancialAuditLog.user_id == user_id)
        .order_by(FinancialAuditLog.timestamp.desc())
        .offset(skip)
        .limit(limit)
    )
    return result.scalars().all()