# app/core/security.py
import logging
from datetime import datetime, timedelta, timezone
from typing import Any, Union, Optional

from jose import JWTError, jwt
from passlib.context import CryptContext

from app.config import settings # Import settings from config

# --- Password Hashing ---

# Shared passlib context used by verify_password() and hash_password().
# bcrypt is the only configured scheme, so it is also the default one.
# deprecated="auto" marks every scheme except the default as deprecated.
# That setting only influences pwd_context.needs_update() and
# pwd_context.verify_and_update(). A plain pwd_context.verify() just returns
# a bool and never rewrites a stored hash. With bcrypt as the sole scheme,
# nothing is currently deprecated.
pwd_context = CryptContext(
    schemes=["bcrypt"],
    deprecated="auto",
)

def verify_password(plain_password: str, hashed_password: str) -> bool:
    """
    Verifies a plain text password against a hashed password.

    The check is delegated to the module-level ``pwd_context``. This function
    never raises: every failure is reported as a non-match (fail closed), but
    failures are logged instead of being swallowed silently.

    Args:
        plain_password: The password attempt.
        hashed_password: The stored hash from the database.

    Returns:
        True if the password matches the hash, False otherwise (including
        when verification itself errors out).
    """
    try:
        return pwd_context.verify(plain_password, hashed_password)
    except (ValueError, TypeError) as exc:
        # Expected bad-input cases: a stored hash passlib cannot identify or
        # parse (passlib raises ValueError subclasses for these), or an
        # argument of the wrong type. Only the exception type is logged so
        # that neither the password nor the stored hash ends up in the logs.
        logging.getLogger(__name__).warning(
            "Password verification failed (%s); treating as mismatch",
            type(exc).__name__,
        )
        return False
    except Exception:
        # Anything else (e.g. a broken or missing hashing backend) is a server
        # problem rather than a wrong password. Keep failing closed, but record
        # the traceback so misconfiguration is visible.
        logging.getLogger(__name__).error(
            "Unexpected error during password verification", exc_info=True
        )
        return False

def hash_password(password: str) -> str:
    """
    Turn a plain text password into a hash suitable for storage.

    Hashing is delegated to the module-level ``pwd_context``, which is
    configured for bcrypt.

    Args:
        password: The plain text password to hash.

    Returns:
        The hash string produced by ``pwd_context``. It can later be checked
        with ``verify_password``.
    """
    digest = pwd_context.hash(password)
    return digest


# --- JSON Web Tokens (JWT) ---

def create_access_token(subject: Union[str, Any], expires_delta: Optional[timedelta] = None) -> str:
    """
    Creates a signed JWT access token.

    Args:
        subject: The subject of the token (e.g., user ID or email). It is
                 converted with ``str()`` before being stored in ``sub``.
        expires_delta: Optional lifetime of the token. If None, uses
                       ACCESS_TOKEN_EXPIRE_MINUTES from settings. Any other
                       value, including ``timedelta(0)``, is used as given.

    Returns:
        The encoded JWT access token string. Its payload holds ``exp``
        (expiry instant in UTC), ``sub`` and ``type="access"``. It is signed
        with settings.SECRET_KEY using settings.ALGORITHM.
    """
    # Compare against None explicitly: a zero timedelta is falsy, and a
    # truthiness test would silently replace it with the default lifetime.
    if expires_delta is None:
        expires_delta = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)

    # Timezone-aware UTC "now" so the 'exp' claim is an unambiguous instant.
    expire = datetime.now(timezone.utc) + expires_delta

    # "type" lets verify_access_token reject refresh tokens and vice versa.
    to_encode = {"exp": expire, "sub": str(subject), "type": "access"}

    return jwt.encode(
        to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM
    )

def create_refresh_token(subject: Union[str, Any], expires_delta: Optional[timedelta] = None) -> str:
    """
    Creates a signed JWT refresh token.

    Args:
        subject: The subject of the token (e.g., user ID or email). It is
                 converted with ``str()`` before being stored in ``sub``.
        expires_delta: Optional lifetime of the token. If None, uses
                       REFRESH_TOKEN_EXPIRE_MINUTES from settings. Any other
                       value, including ``timedelta(0)``, is used as given.

    Returns:
        The encoded JWT refresh token string. Its payload holds ``exp``
        (expiry instant in UTC), ``sub`` and ``type="refresh"``. It is signed
        with settings.SECRET_KEY using settings.ALGORITHM.
    """
    # Compare against None explicitly: a zero timedelta is falsy, and a
    # truthiness test would silently replace it with the default lifetime.
    if expires_delta is None:
        expires_delta = timedelta(minutes=settings.REFRESH_TOKEN_EXPIRE_MINUTES)

    # Timezone-aware UTC "now" so the 'exp' claim is an unambiguous instant.
    expire = datetime.now(timezone.utc) + expires_delta

    # "type" lets verify_refresh_token reject access tokens and vice versa.
    to_encode = {"exp": expire, "sub": str(subject), "type": "refresh"}

    return jwt.encode(
        to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM
    )

def verify_access_token(token: str) -> Optional[dict]:
    """
    Verifies a JWT access token and returns its payload if valid.

    ``jwt.decode`` checks the signature with settings.SECRET_KEY, accepts
    only settings.ALGORITHM, and rejects expired tokens via the ``exp``
    claim. In addition, the payload's ``type`` claim must be "access", so a
    refresh token cannot be used in its place. This function never raises.

    Args:
        token: The JWT token string to verify.

    Returns:
        The decoded token payload (dict) if the token is valid, not expired,
        and is an access token, otherwise None.
    """
    logger = logging.getLogger(__name__)
    try:
        payload = jwt.decode(
            token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM]
        )
    except JWTError as e:
        # In python-jose, bad signatures, expired tokens
        # (ExpiredSignatureError) and failed claim checks (JWTClaimsError)
        # all derive from JWTError. These are routine client-side failures,
        # so they are logged for debugging only.
        logger.debug("JWT Error: %s", e)
        return None
    except Exception:
        # Unexpected failure while decoding: keep the never-raise contract,
        # but record the traceback.
        logger.exception("Unexpected error decoding JWT")
        return None

    if payload.get("type") != "access":
        logger.debug("JWT Error: %s", "Invalid token type")
        return None
    return payload

def verify_refresh_token(token: str) -> Optional[dict]:
    """
    Verifies a JWT refresh token and returns its payload if valid.

    ``jwt.decode`` checks the signature with settings.SECRET_KEY, accepts
    only settings.ALGORITHM, and rejects expired tokens via the ``exp``
    claim. In addition, the payload's ``type`` claim must be "refresh", so an
    access token cannot be used in its place. This function never raises.

    Args:
        token: The JWT token string to verify.

    Returns:
        The decoded token payload (dict) if the token is valid, not expired,
        and is a refresh token, otherwise None.
    """
    logger = logging.getLogger(__name__)
    try:
        payload = jwt.decode(
            token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM]
        )
    except JWTError as e:
        # In python-jose, bad signatures, expired tokens
        # (ExpiredSignatureError) and failed claim checks (JWTClaimsError)
        # all derive from JWTError. These are routine client-side failures,
        # so they are logged for debugging only.
        logger.debug("JWT Error: %s", e)
        return None
    except Exception:
        # Unexpected failure while decoding: keep the never-raise contract,
        # but record the traceback.
        logger.exception("Unexpected error decoding JWT")
        return None

    if payload.get("type") != "refresh":
        logger.debug("JWT Error: %s", "Invalid token type")
        return None
    return payload

# You might add a function here later to extract the 'sub' (subject/user id)
# specifically, often used in dependency injection for authentication.
# def get_subject_from_token(token: str) -> Optional[str]:
#     payload = verify_access_token(token)
#     if payload:
#         return payload.get("sub")
#     return None