Add authorization

This commit is contained in:
2024-05-14 20:55:35 +04:00
parent 0565efdd15
commit 090897a11f
15 changed files with 255 additions and 24 deletions

5
.env.example Normal file
View File

@@ -0,0 +1,5 @@
SQLALCHEMY_DATABASE_URL=sqlite+aiosqlite:///./dev_database.db
IMAGE_TARGET_SIZE=2019600
PREVIEW_TARGET_SIZE=504900
JWT_SECRET_KEY=09d25e094faa6ca2446c818166b7a9565b93f7099f6f2f4caa6cf63b88e8d3e7
JWT_ACCESS_TOKEN_EXPIRE_MINUTES=30

View File

@@ -1 +1,2 @@
from .games import * from .games import *
from .users import *

View File

@@ -7,6 +7,14 @@ from .. import schemas as sch
from ..database import add_transaction from ..database import add_transaction
async def get_games(db: AsyncSession):
return (await db.execute(select(mdl.Game))).scalars().all()
async def get_game(db: AsyncSession, game_id: int):
return await db.get(mdl.Game, game_id)
async def add_game(db: AsyncSession, async def add_game(db: AsyncSession,
game_info: sch.GameCreate, game_info: sch.GameCreate,
user_id: int): user_id: int):
@@ -21,18 +29,16 @@ async def edit_game(db: AsyncSession,
game_id: int, game_id: int,
game_info: sch.GameCreate): game_info: sch.GameCreate):
game = await db.get(mdl.Game, game_id) game = await db.get(mdl.Game, game_id)
game_fields = [c.name for c in mdl.Game.__table__.columns] for key, value in vars(game_info).items():
new_game_info = { if (value and value is not None and getattr(game, key) != value):
**{k: v for k, v in vars(game).items() if k in game_fields}, setattr(game, key, value)
**game_info.model_dump()}
print(game_fields, new_game_info)
game = mdl.Game(**new_game_info)
await db.commit() await db.commit()
return game return game
async def get_games(db: AsyncSession): async def delete_game(db: AsyncSession,
return (await db.execute(select(mdl.Game))).scalars().all() game_id: int):
game = await get_game(db, game_id)
async def get_game(db: AsyncSession, game_id: int): await db.delete(game)
return await db.get(mdl.Game, game_id) await db.commit()
return game

26
database/crud/users.py Normal file
View File

@@ -0,0 +1,26 @@
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from .. import models as mdl
from .. import schemas as sch
from ..database import add_transaction
async def get_user(db: AsyncSession, username: str):
return (await db.execute(select(mdl.User).where(mdl.User.name == username))).scalar()
async def add_user(db: AsyncSession,
user_data: sch.UserCreate, hash_of_password: str):
user_data_db = \
{k: v for k, v in user_data.model_dump().items()
if k != "password"}
user = mdl.User(**user_data_db,
hash_of_password=hash_of_password)
return await add_transaction(db, user)
async def check_email(db: AsyncSession, email: str):
users = (await db.execute(select(mdl.User)
.where(mdl.User.email == email))).scalars().all()
return True if len(users) == 0 else False

View File

@@ -2,7 +2,9 @@ from sqlalchemy.orm import sessionmaker
from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
DATABASE_URL = "sqlite+aiosqlite:///./dev_database.db" from env import Env
DATABASE_URL = Env.get_strict("SQLALCHEMY_DATABASE_URL", str)
# DATABASE_URL = "postgresql://user:password@postgresserver/db" # DATABASE_URL = "postgresql://user:password@postgresserver/db"
engine = create_async_engine( engine = create_async_engine(

View File

@@ -1 +1,2 @@
from .games import * from .games import *
from .users import *

19
database/schemas/users.py Normal file
View File

@@ -0,0 +1,19 @@
from typing import Optional
from fastapi import Body
from pydantic import BaseModel, ConfigDict, Field
class UserBase(BaseModel):
email: str = Field(examples=["email@gmail.com"])
name: str = Field(examples=["username"])
class UserCreate(UserBase):
password: str = Field(examples=["password"])
class User(UserBase):
id: int = Field(examples=[1])
hash_of_password: str = Field(examples=["hash_of_password"])
model_config = ConfigDict(from_attributes=True)

28
env.py Normal file
View File

@@ -0,0 +1,28 @@
import os
from dotenv import dotenv_values, load_dotenv
class Env:
env: dict[str, str | None] = {
**dotenv_values(".env.example"),
**dotenv_values(".env")
}
@staticmethod
def load_environment(path: str):
load_dotenv(path)
Env.env = {**Env.env, **os.environ}
@staticmethod
def get(key: str):
return Env.env.get(key)
@staticmethod
def get_strict[T](key: str, type_: type[T]) -> T:
env_var = Env.env.get(key)
if (env_var is None):
raise ValueError(f"Environment variable {key} not found")
try:
return type_(env_var)
except:
raise ValueError("Environment variable IMAGE_TARGET_SIZE is wrong")

View File

@@ -7,6 +7,11 @@ import aiofiles
from fastapi import UploadFile from fastapi import UploadFile
from PIL import Image from PIL import Image
from env import Env
IMAGE_TARGET_SIZE = Env.get_strict("IMAGE_TARGET_SIZE", int)
PREVIEW_TARGET_SIZE = Env.get_strict("PREVIEW_TARGET_SIZE", int)
def create_hash_name(filename: str): def create_hash_name(filename: str):
# TODO: Hash from file data # TODO: Hash from file data
@@ -48,8 +53,9 @@ async def save_image(cover: UploadFile, type: Literal["cover", "screenshot"]):
raise ValueError("Invalid image file") raise ValueError("Invalid image file")
cover_full_size = Image.open(BytesIO(cover_data)) cover_full_size = Image.open(BytesIO(cover_data))
compressed_coefficient = (cover_full_size.size[0] * compressed_coefficient = \
cover_full_size.size[1]) / (1920*1080) (cover_full_size.size[0] * cover_full_size.size[1]
) / IMAGE_TARGET_SIZE
if (compressed_coefficient < 1): if (compressed_coefficient < 1):
compressed_coefficient = 1 compressed_coefficient = 1
@@ -63,7 +69,9 @@ async def save_image(cover: UploadFile, type: Literal["cover", "screenshot"]):
await full_size_file.write(buf.getbuffer()) await full_size_file.write(buf.getbuffer())
cover_preview = Image.open(BytesIO(cover_data)) cover_preview = Image.open(BytesIO(cover_data))
compressed_coefficient /= 4 compressed_coefficient /= \
(cover_preview.size[0] * cover_preview.size[1]
) / PREVIEW_TARGET_SIZE
if (compressed_coefficient < 1): if (compressed_coefficient < 1):
compressed_coefficient = 1 compressed_coefficient = 1

View File

@@ -16,6 +16,7 @@ app = FastAPI(
app.include_router(startup_router) app.include_router(startup_router)
app.include_router(games_router) app.include_router(games_router)
app.include_router(files_router) app.include_router(files_router)
app.include_router(auth_router)
app.mount("/content", StaticFiles(directory="content"), name="content") app.mount("/content", StaticFiles(directory="content"), name="content")
cli = typer.Typer() cli = typer.Typer()

View File

@@ -5,3 +5,8 @@ aiosqlite==0.20.0
typer==0.12.3 typer==0.12.3
aiofiles==23.2.1 aiofiles==23.2.1
Pillow==10.3.0 Pillow==10.3.0
bcrypt==4.1.3
passlib==1.7.4
cryptography==42.0.7
python-jose==3.3.0
python-dotenv==1.0.1

View File

@@ -1,3 +1,4 @@
from .games import games_router as games_router from .games import games_router as games_router
from .files import files_router as files_router from .files import files_router as files_router
from .startup import startup_router as startup_router from .startup import startup_router as startup_router
from .auth import auth_router as auth_router

115
routes/auth.py Normal file
View File

@@ -0,0 +1,115 @@
from typing import Annotated, Any
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from passlib.context import CryptContext
from datetime import datetime, timedelta, timezone
from fastapi import APIRouter, Depends, status, HTTPException
from sqlalchemy.ext.asyncio import AsyncSession
from pydantic import BaseModel
from jose import JWTError, jwt
import database as db
from env import Env
SECRET_KEY = Env.get_strict("JWT_SECRET_KEY", str)
ACCESS_TOKEN_EXPIRE_MINUTES = \
Env.get_strict("JWT_ACCESS_TOKEN_EXPIRE_MINUTES", int)
crypt = CryptContext(schemes=["bcrypt"], deprecated="auto")
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="auth")
auth_router = APIRouter(prefix="/auth", tags=["Auth"])
class Token(BaseModel):
access_token: str
token_type: str
class TokenData(BaseModel):
username: str
email: str
def check_password(password, hash): return crypt.verify(password, hash)
def get_hash(password): return crypt.hash(password)
async def get_user(token: str = Depends(oauth2_scheme),
db_session: AsyncSession = Depends(db.get_session)):
credentials_exception = HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Could not validate credentials",
headers={"WWW-Authenticate": "Bearer"},
)
try:
payload = jwt.decode(token, SECRET_KEY)
token_data = TokenData(**payload)
except Exception:
raise credentials_exception
user = await db.get_user(db_session, token_data.username)
if user is None:
raise credentials_exception
return user
def create_token(user: db.User):
access_token_expires = \
timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
expire = datetime.now(timezone.utc) + access_token_expires
to_encode = {
"username": user.name,
"email": user.email,
"expire": str(expire)
}
encoded_jwt = jwt.encode(to_encode, SECRET_KEY)
return Token(access_token=encoded_jwt, token_type="bearer")
@auth_router.post("/registration")
async def registration_user(
user_data: db.UserCreate,
db_session: AsyncSession = Depends(db.get_session)
) -> Token:
if (not await db.check_email(db_session, user_data.email)):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="This email is occupied by another user",
headers={"WWW-Authenticate": "Bearer"},
)
elif (await db.get_user(db_session, user_data.name) is not None):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="User with the same name already exists",
headers={"WWW-Authenticate": "Bearer"},
)
else:
user = await db.add_user(db_session, user_data,
get_hash(user_data.password))
return create_token(user)
@auth_router.post("")
async def login_user(
auth_data: OAuth2PasswordRequestForm = Depends(),
db_session: AsyncSession = Depends(db.get_session)
) -> Token:
user = await db.get_user(db_session, auth_data.username)
if (user is None):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="User not found",
headers={"WWW-Authenticate": "Bearer"},
)
if (not check_password(auth_data.password, user.hash_of_password)):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Incorrect password",
headers={"WWW-Authenticate": "Bearer"},
)
return create_token(user)
@auth_router.get("/me", response_model=db.User)
async def read_me(user: db.User = Depends(get_user)):
return user

View File

@@ -3,15 +3,23 @@ from fastapi import APIRouter, Depends
import database as db import database as db
from file_handler import * from file_handler import *
from routes.auth import get_user
games_router = APIRouter(prefix="/games", tags=["Games"]) games_router = APIRouter(prefix="/games", tags=["Games"])
@games_router.get("/", response_model=list[db.Game]) @games_router.get("", response_model=list[db.Game])
async def get_games(db_session: AsyncSession = Depends(db.get_session)): async def get_games(db_session: AsyncSession = Depends(db.get_session)):
return await db.get_games(db_session) return await db.get_games(db_session)
@games_router.post("", response_model=db.Game)
async def add_game(game: db.GameCreate,
user: db.User = Depends(get_user),
db_session: AsyncSession = Depends(db.get_session)):
return await db.add_game(db_session, game, user.id)
@games_router.get("/cards", response_model=list[db.GameCard]) @games_router.get("/cards", response_model=list[db.GameCard])
async def get_games_cards(db_session: AsyncSession = Depends(db.get_session)): async def get_games_cards(db_session: AsyncSession = Depends(db.get_session)):
return await db.get_games(db_session) return await db.get_games(db_session)
@@ -29,8 +37,7 @@ async def edit_game(game_id: int,
return await db.edit_game(db_session, game_id, game) return await db.edit_game(db_session, game_id, game)
@games_router.post("/", response_model=db.Game) @games_router.delete("/{game_id}", response_model=db.Game)
async def add_game(game: db.GameCreate, async def delete_game(game_id: int,
user_id: int,
db_session: AsyncSession = Depends(db.get_session)): db_session: AsyncSession = Depends(db.get_session)):
return await db.add_game(db_session, game, user_id) return await db.delete_game(db_session, game_id)

View File

@@ -1,11 +1,12 @@
from fastapi import APIRouter from fastapi import APIRouter
from pathlib import Path from pathlib import Path
from env import Env
startup_router = APIRouter() startup_router = APIRouter()
@startup_router.on_event("startup") def create_folders():
def startup():
need_paths = [ need_paths = [
Path() / "content" / "images" / "cover" / "full_size", Path() / "content" / "images" / "cover" / "full_size",
Path() / "content" / "images" / "cover" / "preview", Path() / "content" / "images" / "cover" / "preview",
@@ -15,3 +16,8 @@ def startup():
] ]
for path in need_paths: for path in need_paths:
path.mkdir(parents=True, exist_ok=True) path.mkdir(parents=True, exist_ok=True)
@startup_router.on_event("startup")
def startup():
create_folders()