# app/main.py
import logging
import uvicorn
from fastapi import FastAPI, HTTPException, Depends, status, Request
from fastapi.middleware.cors import CORSMiddleware
from starlette.middleware.sessions import SessionMiddleware
import sentry_sdk
from sentry_sdk.integrations.fastapi import FastApiIntegration
from fastapi_users.authentication import JWTStrategy
from pydantic import BaseModel
from jose import jwt, JWTError
from sqlalchemy.ext.asyncio import AsyncEngine
from alembic.config import Config
from alembic import command
import os
import sys

from app.api.api_router import api_router
from app.config import settings
from app.core.api_config import API_METADATA, API_TAGS
from app.auth import fastapi_users, auth_backend, get_refresh_jwt_strategy, get_jwt_strategy
from app.models import User
from app.api.auth.oauth import router as oauth_router
from app.schemas.user import UserPublic, UserCreate, UserUpdate
from app.core.scheduler import init_scheduler, shutdown_scheduler
from app.database import get_session
from sqlalchemy import select

# Response model for refresh endpoint
class RefreshResponse(BaseModel):
    """Response body of ``POST /auth/jwt/refresh``: a newly issued token pair."""

    access_token: str  # new access token
    refresh_token: str  # new refresh token
    token_type: str = "bearer"  # scheme to use in the Authorization header

# --- Sentry (optional) ---
def _init_sentry() -> None:
    """Configure Sentry reporting; does nothing unless ``SENTRY_DSN`` is set."""
    if not settings.SENTRY_DSN:
        return
    sentry_sdk.init(
        dsn=settings.SENTRY_DSN,
        integrations=[FastApiIntegration()],
        # Trace every transaction outside production, 10% in production.
        traces_sample_rate=1.0 if not settings.is_production else 0.1,
        environment=settings.ENVIRONMENT,
        # Default PII is only sent when not running in production.
        send_default_pii=not settings.is_production,
    )


_init_sentry()

# --- Logging Setup ---
def _resolve_log_level(level) -> int:
    """Translate the configured log level into a ``logging`` level number.

    Accepts an int (returned unchanged) or a level name in any letter case,
    e.g. ``"info"``, ``"INFO"`` or ``" Warning "``.

    Raises:
        ValueError: if ``level`` is not a level name known to ``logging``.
    """
    if isinstance(level, int):
        return level
    resolved = logging.getLevelName(str(level).strip().upper())
    # getLevelName returns the number for a known name and the string
    # "Level <name>" for anything it does not recognise.
    if not isinstance(resolved, int):
        raise ValueError(
            f"Unknown LOG_LEVEL {level!r}; expected one of "
            "DEBUG, INFO, WARNING, ERROR, CRITICAL"
        )
    return resolved


logging.basicConfig(
    level=_resolve_log_level(settings.LOG_LEVEL),
    format=settings.LOG_FORMAT
)
logger = logging.getLogger(__name__)

# --- FastAPI App Instance ---
# Static API metadata, with the docs / ReDoc / OpenAPI URLs overridden from
# the environment-dependent settings.
api_metadata = dict(
    API_METADATA,
    docs_url=settings.docs_url,
    redoc_url=settings.redoc_url,
    openapi_url=settings.openapi_url,
)

app = FastAPI(openapi_tags=API_TAGS, **api_metadata)

# --- Session & CORS Middleware ---
# Session storage used by the OAuth routes.
app.add_middleware(SessionMiddleware, secret_key=settings.SESSION_SECRET_KEY)

# CORS for the configured origins, with credentials; all methods and request
# headers are allowed and "*" is advertised for exposed response headers.
# NOTE(review): browsers treat `Access-Control-Expose-Headers: *` literally
# (not as a wildcard) on credentialed requests, so with allow_credentials=True
# custom response headers are not actually exposed - consider listing them.
app.add_middleware(
    CORSMiddleware,
    allow_origins=settings.cors_origins_list,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
    expose_headers=["*"],
)
# --- End CORS Middleware ---

# Refresh token endpoint
@app.post("/auth/jwt/refresh", response_model=RefreshResponse, tags=["auth"])
async def refresh_jwt_token(
    request: Request,
    refresh_strategy: JWTStrategy = Depends(get_refresh_jwt_strategy),
    access_strategy: JWTStrategy = Depends(get_jwt_strategy),
    user_manager=Depends(fastapi_users.get_user_manager),
):
    """
    Refresh access token using a valid refresh token.
    Send refresh token in Authorization header: Bearer <refresh_token>

    The refresh token is verified by the strategy that issued it (same key,
    audience and algorithm), and the user is loaded through ``user_manager``,
    which also parses the ``sub`` claim into the correct ID type.

    Returns:
        RefreshResponse holding a new access token and a new refresh token.

    Raises:
        HTTPException(401): missing/malformed Authorization header, invalid or
            expired token, unknown or inactive user, or any unexpected error
            while refreshing (logged with its traceback).
    """

    def _unauthorized(detail: str) -> HTTPException:
        # A 401 response must carry a WWW-Authenticate challenge (RFC 7235).
        return HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=detail,
            headers={"WWW-Authenticate": "Bearer"},
        )

    # Auth scheme names are case-insensitive, so "bearer <token>" is accepted too.
    scheme, _, refresh_token = (request.headers.get("Authorization") or "").partition(" ")
    refresh_token = refresh_token.strip()
    if scheme.lower() != "bearer" or not refresh_token:
        raise _unauthorized("Refresh token missing or invalid format")

    try:
        # Let the issuing strategy verify the token: a bare jwt.decode without
        # its audience rejects the `aud` claim that JWTStrategy embeds.
        # NOTE(review): if the access and refresh strategies share the same
        # secret and token_audience (app.auth is not visible here - confirm),
        # an access token would also be accepted; give the refresh strategy a
        # distinct token_audience to prevent that.
        user = await refresh_strategy.read_token(refresh_token, user_manager)
        if user is None:
            raise _unauthorized("Invalid refresh token")
        if not user.is_active:
            raise _unauthorized("User not found or inactive")

        new_access_token = await access_strategy.write_token(user)
        new_refresh_token = await refresh_strategy.write_token(user)
    except HTTPException:
        raise
    except Exception:
        # NOTE(review): unexpected failures (e.g. database errors) are still
        # reported as 401 to keep the existing contract; a 5xx may be more apt.
        logger.exception("Error refreshing token")
        raise _unauthorized("Invalid refresh token")

    return RefreshResponse(
        access_token=new_access_token,
        refresh_token=new_refresh_token,
        token_type="bearer",
    )

# --- Include API Routers ---
def _include_routers() -> None:
    """Mount all routers on ``app``.

    OAuth routes come first (no auth required), then the FastAPI-Users
    auth (JWT), register, reset-password, verify and users routers, and
    finally the application's own API under ``settings.API_PREFIX``.
    """
    auth_routers = [
        (oauth_router, "/auth", "auth"),
        (fastapi_users.get_auth_router(auth_backend), "/auth/jwt", "auth"),
        (fastapi_users.get_register_router(UserPublic, UserCreate), "/auth", "auth"),
        (fastapi_users.get_reset_password_router(), "/auth", "auth"),
        (fastapi_users.get_verify_router(UserPublic), "/auth", "auth"),
        (fastapi_users.get_users_router(UserPublic, UserUpdate), "/users", "users"),
    ]
    for router, prefix, tag in auth_routers:
        app.include_router(router, prefix=prefix, tags=[tag])

    # The application's own API router.
    app.include_router(api_router, prefix=settings.API_PREFIX)


_include_routers()
# --- End Include API Routers ---

# Health check endpoint
@app.get("/health", tags=["Health"])
async def health_check():
    """
    Report service status together with the environment and API version.

    Meant for load balancers and monitoring probes.
    """
    payload = dict(
        status=settings.HEALTH_STATUS_OK,
        environment=settings.ENVIRONMENT,
        version=settings.API_VERSION,
    )
    return payload

# --- Root Endpoint (Optional - outside the main API structure) ---
@app.get("/", tags=["Root"])
async def read_root():
    """
    Return the configured welcome message, environment and API version.

    Logs each hit; handy as a quick reachability check.
    """
    logger.info("Root endpoint '/' accessed.")
    message = settings.ROOT_MESSAGE
    environment = settings.ENVIRONMENT
    version = settings.API_VERSION
    return {"message": message, "environment": environment, "version": version}
# --- End Root Endpoint ---

async def run_migrations():
    """Apply database migrations via the helper module in ``alembic/``.

    The ``alembic`` directory that sits next to this package's directory is
    prepended to ``sys.path`` (once), so that its ``migrations`` module can be
    imported; that module's ``run_migrations`` coroutine is then awaited.

    Raises:
        Exception: anything raised while locating, importing or running the
            migrations; it is logged first and then propagated unchanged.
    """
    try:
        logger.info("Running database migrations...")
        project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
        migrations_dir = os.path.join(project_root, 'alembic')

        # Add the directory only once so repeated calls don't grow sys.path.
        if migrations_dir not in sys.path:
            sys.path.insert(0, migrations_dir)

        from migrations import run_migrations as apply_migrations
        await apply_migrations()

        logger.info("Database migrations completed successfully.")
    except Exception as exc:
        logger.error(f"Error running migrations: {exc}")
        raise

@app.on_event("startup")
async def startup_event():
    """Log startup and start the background scheduler.

    Database migrations are not applied here: the ``run_migrations()`` call
    is currently disabled.
    """
    logger.info(f"Application startup in {settings.ENVIRONMENT} environment...")
    init_scheduler()
    logger.info("Application startup complete.")

@app.on_event("shutdown")
async def shutdown_event():
    """Stop background services when the application shuts down.

    Only the scheduler is stopped; no database connections are closed here.
    """
    # The log line describes what actually happens (no database disconnect
    # is performed in this handler).
    logger.info("Application shutdown: stopping scheduler...")
    # TODO(review): dispose the SQLAlchemy engine here (e.g. `await engine.dispose()`)
    # if app.database exposes one - confirm.
    shutdown_scheduler()
    logger.info("Application shutdown complete.")
# --- End Events ---


# --- Direct Run (for simple local testing if needed) ---
# It's better to use `uvicorn app.main:app --reload` from the terminal
# if __name__ == "__main__":
#     logger.info("Starting Uvicorn server directly from main.py")
#     uvicorn.run(app, host="0.0.0.0", port=8000)
# ------------------------------------------------------