140 lines
4.7 KiB
Python
140 lines
4.7 KiB
Python
import logging
from typing import Optional

from authlib.integrations.starlette_client import OAuth
from fastapi import Depends, Request
from fastapi.responses import JSONResponse
from fastapi.security import OAuth2PasswordRequestForm
from fastapi_users import BaseUserManager, FastAPIUsers, IntegerIDMixin
from fastapi_users.authentication import (
    AuthenticationBackend,
    BearerTransport,
    JWTStrategy,
)
from fastapi_users.db import SQLAlchemyUserDatabase
from pydantic import BaseModel
from sqlalchemy.ext.asyncio import AsyncSession
from starlette.config import Config
from starlette.middleware.sessions import SessionMiddleware
from starlette.responses import Response

from .database import get_session
from .models import User
from .config import settings
|
|
|
|
# Authlib reads provider credentials (e.g. GOOGLE_CLIENT_ID / GOOGLE_CLIENT_SECRET)
# from this config, keyed by the registered provider name.
config = Config('.env')
oauth = OAuth(config)


def _register_oidc_provider(name: str, metadata_url: str, scope: str, redirect_uri) -> None:
    """Register an OpenID Connect provider on the module-level ``oauth`` registry.

    Endpoints are discovered from ``metadata_url``; ``scope`` and
    ``redirect_uri`` are forwarded to the underlying OAuth client.
    """
    oauth.register(
        name=name,
        server_metadata_url=metadata_url,
        client_kwargs={
            'scope': scope,
            'redirect_uri': redirect_uri,
        },
    )


_register_oidc_provider(
    'google',
    'https://accounts.google.com/.well-known/openid-configuration',
    'openid email profile',
    settings.GOOGLE_REDIRECT_URI,
)

# NOTE(review): Sign in with Apple expects a client secret that is a signed JWT
# and, when name/email scopes are requested, response_mode=form_post — confirm
# both are handled wherever the authorize/callback routes live.
_register_oidc_provider(
    'apple',
    'https://appleid.apple.com/.well-known/openid-configuration',
    'openid email name',
    settings.APPLE_REDIRECT_URI,
)
|
|
|
|
class BearerResponseWithRefresh(BaseModel):
    """Login response body carrying both an access token and a refresh token."""

    # JWT sent as ``Authorization: Bearer <access_token>`` on API requests.
    access_token: str
    # Longer-lived token issued alongside the access token at login.
    refresh_token: str
    # Token type reported to the client; always "bearer" in this module.
    token_type: str = "bearer"
|
|
|
|
class BearerTransportWithRefresh(BearerTransport):
    """Bearer transport whose login response can also carry a refresh token."""

    async def get_login_response(
        self, token: str, refresh_token: Optional[str] = None
    ) -> Response:
        """Build the JSON body returned by the login endpoint.

        Args:
            token: Access token issued for the user.
            refresh_token: Optional refresh token. When it is empty or None,
                the response contains only the access token.

        Returns:
            A ``JSONResponse`` with ``access_token``, ``token_type`` and, when
            present, ``refresh_token``.
        """
        # Build the payload directly rather than calling the model's
        # ``.dict()``. That method is deprecated under Pydantic v2, and the old
        # hasattr() check mixed model and dict handling.
        if refresh_token:
            payload = {
                "access_token": token,
                "refresh_token": refresh_token,
                "token_type": "bearer",
            }
        else:
            payload = {
                "access_token": token,
                "token_type": "bearer",
            }
        return JSONResponse(payload)
|
|
|
|
class AuthenticationBackendWithRefresh(AuthenticationBackend):
    """Authentication backend that also mints a refresh token at login.

    The access token is produced by the strategy passed into ``login``. The
    refresh token is produced by a second strategy obtained from
    ``get_refresh_strategy``. Logout is inherited from ``AuthenticationBackend``.
    The base implementation tolerates strategies that cannot destroy tokens and
    transports (such as bearer) that have no logout response. The former
    override called the bearer transport's logout directly, which raises.
    """

    def __init__(
        self,
        name: str,
        transport: BearerTransportWithRefresh,
        get_strategy,
        get_refresh_strategy,
    ):
        """
        Args:
            name: Backend name used in route paths.
            transport: Transport able to emit access and refresh tokens.
            get_strategy: Dependency returning the access-token strategy.
            get_refresh_strategy: Zero-argument callable returning the
                strategy used to write refresh tokens.
        """
        # Let the base class set name/transport/get_strategy so that any extra
        # state it initialises is also present.
        super().__init__(name, transport, get_strategy)
        self.get_refresh_strategy = get_refresh_strategy

    async def login(self, strategy, user) -> Response:
        """Issue an access token and a refresh token for ``user``.

        Returns the transport's login response containing both tokens.
        """
        access_token = await strategy.write_token(user)
        refresh_strategy = self.get_refresh_strategy()
        refresh_token = await refresh_strategy.write_token(user)

        return await self.transport.get_login_response(
            token=access_token,
            refresh_token=refresh_token,
        )
|
|
|
|
logger = logging.getLogger(__name__)


class UserManager(IntegerIDMixin, BaseUserManager[User, int]):
    """fastapi-users manager for ``User`` rows with integer primary keys.

    Password-reset and verification tokens are both signed with
    ``settings.SECRET_KEY``. The lifecycle hooks below only log events.
    """

    reset_password_token_secret = settings.SECRET_KEY
    verification_token_secret = settings.SECRET_KEY

    async def on_after_register(self, user: User, request: Optional[Request] = None):
        """Hook run after a user registers; logs the new user's id."""
        logger.info("User %s has registered.", user.id)

    async def on_after_forgot_password(
        self, user: User, token: str, request: Optional[Request] = None
    ):
        """Hook run after a password reset is requested.

        The reset token is a bearer secret. Anyone who reads it can reset the
        password, so it is deliberately not logged.
        """
        # TODO: deliver ``token`` to the user out-of-band (e.g. email).
        logger.info("User %s has forgotten their password.", user.id)

    async def on_after_request_verify(
        self, user: User, token: str, request: Optional[Request] = None
    ):
        """Hook run after e-mail verification is requested.

        The verification token is not logged, for the same reason as reset
        tokens.
        """
        # TODO: deliver ``token`` to the user out-of-band (e.g. email).
        logger.info("Verification requested for user %s.", user.id)

    async def on_after_login(
        self,
        user: User,
        request: Optional[Request] = None,
        response: Optional[Response] = None,
    ):
        """Hook run after a successful login; logs the user's id."""
        logger.info("User %s has logged in.", user.id)
|
|
|
|
async def get_user_db(session: AsyncSession = Depends(get_session)):
    """FastAPI dependency yielding a ``SQLAlchemyUserDatabase`` for ``User``.

    The adapter is bound to the session supplied by ``get_session``.
    """
    user_db = SQLAlchemyUserDatabase(session, User)
    yield user_db
|
|
|
|
async def get_user_manager(user_db: SQLAlchemyUserDatabase = Depends(get_user_db)):
    """FastAPI dependency yielding a ``UserManager`` over the injected user DB."""
    manager = UserManager(user_db)
    yield manager
|
|
|
|
# Bearer transport for the JWT backend. ``tokenUrl`` is the login path that
# clients (and the OpenAPI docs) post credentials to.
bearer_transport = BearerTransportWithRefresh(
    tokenUrl="auth/jwt/login",
)
|
|
|
|
def get_jwt_strategy() -> JWTStrategy:
    """Return the JWT strategy used to write and read access tokens.

    Tokens are signed with ``settings.SECRET_KEY``. They expire after
    ``settings.ACCESS_TOKEN_EXPIRE_MINUTES``.
    """
    access_lifetime = settings.ACCESS_TOKEN_EXPIRE_MINUTES * 60  # minutes -> seconds
    return JWTStrategy(
        secret=settings.SECRET_KEY,
        lifetime_seconds=access_lifetime,
    )
|
|
|
|
# Refresh tokens stay valid for one week.
REFRESH_TOKEN_LIFETIME_SECONDS = 7 * 24 * 60 * 60
# Audience stamped into refresh tokens. It differs from the access-token
# strategy's default audience, so a refresh token (same secret, longer life)
# is rejected when presented as an access token.
REFRESH_TOKEN_AUDIENCE = "fastapi-users:refresh"


def get_refresh_jwt_strategy() -> JWTStrategy:
    """Return the JWT strategy used to write and read refresh tokens.

    Tokens are signed with ``settings.SECRET_KEY``. They live for
    ``REFRESH_TOKEN_LIFETIME_SECONDS`` and carry ``REFRESH_TOKEN_AUDIENCE``.
    Because of that audience, only this strategy (not the access-token one)
    accepts them.
    """
    return JWTStrategy(
        secret=settings.SECRET_KEY,
        lifetime_seconds=REFRESH_TOKEN_LIFETIME_SECONDS,
        # A fresh list per call, so no mutable state is shared between instances.
        token_audience=[REFRESH_TOKEN_AUDIENCE],
    )
|
|
|
|
# The single "jwt" backend: bearer transport, an access-token strategy, and a
# separate strategy used only to mint refresh tokens at login.
auth_backend = AuthenticationBackendWithRefresh(
    name="jwt",
    transport=bearer_transport,
    get_refresh_strategy=get_refresh_jwt_strategy,
    get_strategy=get_jwt_strategy,
)
|
|
|
|
# fastapi-users entry point: ties the user-manager dependency to the
# available authentication backends.
fastapi_users = FastAPIUsers[User, int](get_user_manager, [auth_backend])

# Route dependencies: any active user, or active superusers only.
current_active_user = fastapi_users.current_user(active=True)
current_superuser = fastapi_users.current_user(superuser=True, active=True)