import calendar
import logging
from datetime import datetime, timedelta
from typing import Optional

from sqlalchemy import select, and_
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import contains_eager

from app.crud.expense import create_expense
from app.models import Expense, RecurrencePattern
from app.schemas.expense import ExpenseCreate

# Module-level logger; the recurring-expense job reports per-expense and job-level failures through it.
logger = logging.getLogger(__name__)

async def generate_recurring_expenses(db: AsyncSession) -> None:
    """
    Background job to generate recurring expenses.
    Should be run daily to check for and create new recurring expenses.

    Selects every recurring expense whose ``next_occurrence`` is due and whose
    pattern has neither run out of occurrences nor passed its end date, then
    generates one occurrence per expense. Each expense is processed in its own
    SAVEPOINT: a failure is logged and only that expense's changes are
    discarded, so the remaining expenses are still processed.

    Only one occurrence is generated per expense per run; an expense that is
    several periods behind catches up one period per run.

    No commit is issued here; presumably the caller commits the session --
    confirm against the job runner.

    Raises:
        Exception: any error raised while querying for due expenses is logged
            and re-raised.
    """
    try:
        # NOTE(review): naive UTC timestamp; assumes next_occurrence / end_date
        # are stored as naive UTC datetimes -- confirm against the model.
        now = datetime.utcnow()
        query = (
            select(Expense)
            .join(Expense.recurrence_pattern)
            # Populate expense.recurrence_pattern from this join. Otherwise the
            # attribute access in _generate_next_occurrence would need an
            # implicit lazy load, which AsyncSession cannot perform.
            .options(contains_eager(Expense.recurrence_pattern))
            .where(
                and_(
                    Expense.is_recurring.is_(True),
                    Expense.next_occurrence <= now,
                    # Check if we haven't reached max occurrences
                    (
                        RecurrencePattern.max_occurrences.is_(None)
                        | (RecurrencePattern.max_occurrences > 0)
                    ),
                    # Check if we haven't reached end date
                    (
                        RecurrencePattern.end_date.is_(None)
                        | (RecurrencePattern.end_date > now)
                    ),
                )
            )
        )

        result = await db.execute(query)
        recurring_expenses = result.scalars().all()

        for expense in recurring_expenses:
            # Read the id up front: after a savepoint rollback the instance may
            # be expired, and reading its attributes would need a lazy load.
            expense_id = expense.id
            try:
                # SAVEPOINT per expense: if generation fails, only this
                # expense's changes are rolled back and the session stays
                # usable for the next one.
                # NOTE(review): assumes create_expense does not commit the
                # session itself (that would end the savepoint early) -- confirm.
                async with db.begin_nested():
                    await _generate_next_occurrence(db, expense)
            except Exception as e:
                logger.exception(
                    "Error generating next occurrence for expense %s: %s", expense_id, e
                )

    except Exception as e:
        logger.exception("Error in generate_recurring_expenses job: %s", e)
        raise

async def _generate_next_occurrence(db: AsyncSession, expense: Expense) -> None:
    """Generate the currently due occurrence of a recurring expense.

    Creates a non-recurring copy of ``expense`` dated ``expense.next_occurrence``.
    It then advances the template:

    - ``last_occurrence`` becomes that date.
    - ``next_occurrence`` becomes the following date per the recurrence pattern.
    - A positive ``max_occurrences`` is decremented.

    Changes are flushed, not committed.

    Does nothing if the expense has no recurrence pattern, has no
    ``next_occurrence``, or the pattern type is not recognised by
    ``_calculate_next_occurrence``.
    """
    pattern = expense.recurrence_pattern
    if not pattern:
        return

    # next_occurrence is the due date (the job selects expenses with
    # next_occurrence <= now), so the generated instance is dated on it.
    # Dating the instance one period later and then advancing two periods
    # would skip every other occurrence.
    # NOTE(review): assumes next_occurrence is initialised to the first
    # occurrence *after* the template's own expense_date -- confirm against
    # the code that creates recurring expenses.
    occurrence_date = expense.next_occurrence
    if not occurrence_date:
        return

    # Compute everything that reads the template/pattern before create_expense
    # runs, in case that call commits and expires loaded instances.
    following_date = _calculate_next_occurrence(occurrence_date, pattern)
    if not following_date:
        return

    # Create new expense based on template
    new_expense = ExpenseCreate(
        description=expense.description,
        total_amount=expense.total_amount,
        currency=expense.currency,
        expense_date=occurrence_date,
        split_type=expense.split_type,
        list_id=expense.list_id,
        group_id=expense.group_id,
        item_id=expense.item_id,
        paid_by_user_id=expense.paid_by_user_id,
        is_recurring=False,  # Generated expenses are not recurring
        # NOTE(review): relies on create_expense deriving splits from
        # split_type; split types that need explicit amounts may be
        # rejected -- confirm.
        splits_in=None,
    )
    await create_expense(db, new_expense, expense.created_by_user_id)

    # Advance the template by exactly one period.
    expense.last_occurrence = occurrence_date
    expense.next_occurrence = following_date

    # The job query excludes patterns whose max_occurrences has reached 0.
    if pattern.max_occurrences:
        pattern.max_occurrences -= 1

    await db.flush()

def _calculate_next_occurrence(current_date: datetime, pattern: RecurrencePattern) -> Optional[datetime]:
    """Calculate the next occurrence date based on the pattern."""
    if not current_date:
        return None
        
    if pattern.type == 'daily':
        return current_date + timedelta(days=pattern.interval)
        
    elif pattern.type == 'weekly':
        if not pattern.days_of_week:
            return current_date + timedelta(weeks=pattern.interval)
            
        # Find next day of week
        current_weekday = current_date.weekday()
        next_weekday = min((d for d in pattern.days_of_week if d > current_weekday), 
                          default=min(pattern.days_of_week))
        days_ahead = next_weekday - current_weekday
        if days_ahead <= 0:
            days_ahead += 7
        return current_date + timedelta(days=days_ahead)
        
    elif pattern.type == 'monthly':
        # Add months to current date
        year = current_date.year + (current_date.month + pattern.interval - 1) // 12
        month = (current_date.month + pattern.interval - 1) % 12 + 1
        return current_date.replace(year=year, month=month)
        
    elif pattern.type == 'yearly':
        return current_date.replace(year=current_date.year + pattern.interval)
        
    return None