Skip to content

Authentication with JWT, Part 1 #188

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 27 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 23 commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
29a0453
WIP not quite working
lotif Apr 29, 2025
7a56d46
WIP generally working now, need ot work on hashing
lotif Apr 30, 2025
ef1e313
Now working with double hashing :)
lotif Apr 30, 2025
ac41030
Small refactorings and a fix
lotif Apr 30, 2025
fd05781
moving auth to its own route
lotif Apr 30, 2025
d7dca4b
Making client auth route, adding hashed password to the client and to…
lotif Apr 30, 2025
4dcc311
Client auth done, fixed unit tests
lotif May 1, 2025
125b101
Starting the login page
lotif May 1, 2025
cef80ae
Login form submit working
lotif May 2, 2025
6eb83ad
Saving the token and redirecting the user
lotif May 2, 2025
050b127
forgot to submit package json
lotif May 2, 2025
b70db0e
checking the token on the sidebar
lotif May 2, 2025
6fb3cba
Redirect to login on 401, enabling auth in all calls
lotif May 2, 2025
d20fbd2
Done with most of it, needs tests
lotif May 2, 2025
27afe1e
Fixed unit tests, added token tests
lotif May 5, 2025
19ccba6
More tests
lotif May 5, 2025
21e6ae0
Adding test for auth
lotif May 5, 2025
f631b4d
Starting to implement front end tests
lotif May 6, 2025
e658bae
Finished front end tests
lotif May 6, 2025
42881cb
adding test for expired token
lotif May 6, 2025
bb96a68
[pre-commit.ci] Add auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 6, 2025
88fca2c
Adding warning to the hash function
lotif May 6, 2025
5b545f8
Merge remote-tracking branch 'origin/auth' into auth
lotif May 6, 2025
fa748ca
Obscuring the client hashed password returned by the back end, fixing…
lotif May 7, 2025
2ffd7fa
better return statement on check token functions
lotif May 8, 2025
55fbcf1
small code cleanup
lotif May 8, 2025
22f8e1f
[pre-commit.ci] Add auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 8, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions florist/api/auth/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Module for handling authentication and token creation."""
148 changes: 148 additions & 0 deletions florist/api/auth/token.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
"""Module for handling token and user creation."""

import hashlib
from datetime import datetime, timedelta, timezone
from typing import Any

import bcrypt
import jwt
from motor.motor_asyncio import AsyncIOMotorDatabase
from pydantic import BaseModel

from florist.api.db.client_entities import UserDAO
from florist.api.db.server_entities import User


ENCRYPTION_ALGORITHM = "HS256"
DEFAULT_USERNAME = "admin"
DEFAULT_PASSWORD = "admin"
TOKEN_EXPIRATION_TIMEDELTA = timedelta(days=7)


class Token(BaseModel):
"""Define the Token model."""

access_token: str
token_type: str

class Config:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is probably going to be a dumb question, by what is the point of these inner classes? Perhaps they are expected to be defined for objects inheriting from BaseModel and are therefore used somewhere in outside library code, but I don't see where these get used in our stuff 🙂 (Some question for the inner class of AuthUser.

"""Config for the Token model."""

allow_population_by_field_name = True
schema_extra = {
"example": {
"access_token": "LQv3c1yqBWVHxkd0LHAkCOYz6T",
"token_type": "bearer",
},
}


class AuthUser(BaseModel):
"""Define the User model to be returned by the API."""

uuid: str
username: str

class Config:
"""Config for the AuthUser model."""

allow_population_by_field_name = True
schema_extra = {
"example": {
"uuid": "LQv3c1yqBWVHxkd0LHAkCOYz6T",
"username": "admin",
},
}


def _simple_hash(word: str) -> str:
"""
Hash a word with sha256.

WARNING: This is not a secure hash function, it is only meant to obscure
plain text words. DO NOT use this for generating encrypted passwords for the
authentication users. For that, use the _password_hash function instead.

:param word: (str) the word to hash.
:return: (str) the word hashed as a sha256 hexadecimal string.
"""
return hashlib.sha256(word.encode("utf-8")).hexdigest()


def _password_hash(password: str) -> str:
"""
Hash a password with bcrypt.

:param password: (str) the password to hash.
:return: (str) the hashed password.
"""
password_bytes = password.encode("utf-8")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is perhaps too pedantic, but should we first check that the password given is utf-8 admissible? I'm thinking of what might happen if someone tries to use an emoji as a password 😂

salt = bcrypt.gensalt()
hashed_password = bcrypt.hashpw(password=password_bytes, salt=salt)
return hashed_password.decode("utf-8")


def verify_password(password: str, hashed_password: str) -> bool:
"""
Verify if a password matches a hashed password.

:param password: (str) the password to verify.
:param hashed_password: (str) the hashed password to verify against.
:return: (bool) True if the password matches the hashed password, False otherwise.
"""
return bcrypt.checkpw(password.encode("utf-8"), hashed_password.encode("utf-8"))


async def make_default_server_user(database: AsyncIOMotorDatabase[Any]) -> User:
"""
Make a default server user.

:param database: (AsyncIOMotorDatabase[Any]) the database to create the user in.
:return: (User) the default server user.
"""
hashed_password = _password_hash(_simple_hash(DEFAULT_PASSWORD))
user = User(username=DEFAULT_USERNAME, hashed_password=hashed_password)
await user.create(database)
return user


def make_default_client_user() -> UserDAO:
"""
Make a default client user.

:return: (User) the default client user.
"""
hashed_password = _password_hash(_simple_hash(DEFAULT_PASSWORD))
user = UserDAO(username=DEFAULT_USERNAME, hashed_password=hashed_password)
user.save()
return user


def create_access_token(
data: dict[str, Any], secret_key: str, expiration_delta: timedelta = TOKEN_EXPIRATION_TIMEDELTA
) -> str:
"""
Create an access token.

:param data: (dict) the data to encode in the token.
:param secret_key: (str) the user's secret key to encode the token.
:param expiration_delta: (timedelta) the expiration time of the token.
:return: (str) the access token.
"""
to_encode = data.copy()
expire = datetime.now(timezone.utc) + expiration_delta
to_encode.update({"exp": expire})
return jwt.encode(to_encode, secret_key, algorithm=ENCRYPTION_ALGORITHM)


def decode_access_token(token: str, secret_key: str) -> dict[str, Any]:
"""
Decode an access token.

:param token: (str) the token to decode.
:param secret_key: (str) the user's secret key to decode the token.
:return: (dict) the decoded token information.
"""
data = jwt.decode(token, secret_key, algorithms=[ENCRYPTION_ALGORITHM])
assert isinstance(data, dict)
return data
36 changes: 27 additions & 9 deletions florist/api/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,30 +3,46 @@
import logging
import os
import signal
from contextlib import asynccontextmanager
from pathlib import Path
from typing import Any, AsyncGenerator
from uuid import uuid4

import torch
from fastapi import FastAPI
from fastapi import Depends, FastAPI
from fastapi.responses import JSONResponse
from fl4health.utils.metrics import Accuracy

from florist.api.auth.token import DEFAULT_USERNAME, make_default_client_user
from florist.api.clients.clients import Client
from florist.api.clients.optimizers import Optimizer
from florist.api.db.client_entities import ClientDAO
from florist.api.db.client_entities import ClientDAO, UserDAO
from florist.api.launchers.local import launch_client
from florist.api.models.models import Model
from florist.api.monitoring.logs import get_client_log_file_path
from florist.api.monitoring.metrics import RedisMetricsReporter, get_from_redis, get_host_and_port_from_address
from florist.api.routes.client.auth import check_token
from florist.api.routes.client.auth import router as auth_router


app = FastAPI()
LOGGER = logging.getLogger("uvicorn.error")


LOGGER = logging.getLogger("uvicorn.error")
@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncGenerator[Any, Any]:
"""Set up function for app startup and shutdown."""
# Create default user if it does not exist
if not UserDAO.exists(DEFAULT_USERNAME):
make_default_client_user()

yield


app = FastAPI(lifespan=lifespan)
app.include_router(auth_router, tags=["auth"], prefix="/api/client/auth")

@app.get("/api/client/connect")

@app.get("/api/client/connect", dependencies=[Depends(check_token)])
def connect() -> JSONResponse:
"""
Confirm the client is up and ready to accept instructions.
Expand All @@ -36,7 +52,7 @@ def connect() -> JSONResponse:
return JSONResponse({"status": "ok"})


@app.get("/api/client/start")
@app.get("/api/client/start", dependencies=[Depends(check_token)])
def start(
server_address: str,
client: Client,
Expand All @@ -53,6 +69,7 @@ def start(
:param client: (Client) the client to be used for training.
:param data_path: (str) the path where the training data is located.
:param redis_address: (str) the address for the Redis instance for metrics reporting.

:return: (JSONResponse) If successful, returns 200 with a JSON containing the UUID for the client in the
format below, which can be used to pull metrics from Redis.
{
Expand Down Expand Up @@ -95,7 +112,7 @@ def start(
return JSONResponse({"error": str(ex)}, status_code=500)


@app.get("/api/client/check_status/{client_uuid}")
@app.get("/api/client/check_status/{client_uuid}", dependencies=[Depends(check_token)])
def check_status(client_uuid: str, redis_address: str) -> JSONResponse:
"""
Retrieve value at key client_uuid in redis if it exists.
Expand All @@ -120,7 +137,7 @@ def check_status(client_uuid: str, redis_address: str) -> JSONResponse:
return JSONResponse({"error": str(ex)}, status_code=500)


@app.get("/api/client/get_log/{uuid}")
@app.get("/api/client/get_log/{uuid}", dependencies=[Depends(check_token)])
def get_log(uuid: str) -> JSONResponse:
"""
Return the contents of the logs for the given client uuid.
Expand All @@ -147,12 +164,13 @@ def get_log(uuid: str) -> JSONResponse:
return JSONResponse({"error": str(ex)}, status_code=500)


@app.get("/api/client/stop/{uuid}")
@app.get("/api/client/stop/{uuid}", dependencies=[Depends(check_token)])
def stop(uuid: str) -> JSONResponse:
"""
Stop the client with given UUID.

:param uuid: (str) the UUID of the client to be stopped.

:return: (JSONResponse) If successful, returns 200. If not successful, returns the appropriate
error code with a JSON with the format below:
{"error": <error message>}
Expand Down
55 changes: 54 additions & 1 deletion florist/api/db/client_entities.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Definitions for the SQLIte database entities (client database)."""

import json
import secrets
import sqlite3
from abc import ABC, abstractmethod
from typing import Optional
Expand Down Expand Up @@ -55,7 +56,7 @@ def find(cls, uuid: str) -> Self:
for result in results:
return cls.from_json(result[1])

raise ValueError(f"Client with uuid '{uuid}' not found.")
raise ValueError(f"{cls.table_name} with uuid '{uuid}' not found.")

@classmethod
def exists(cls, uuid: str) -> bool:
Expand Down Expand Up @@ -167,3 +168,55 @@ def to_json(self) -> str:
"pid": self.pid,
}
)


class UserDAO(EntityDAO):
"""Data Access Object (DAO) for the User SQLite entity."""

table_name = "User"

def __init__(self, username: str, hashed_password: str):
"""
Initialize a User entity.

:param uuid: (str) the UUID of the user.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I might be missing something here, but I don't think uuid is a param?

:param username: (str) the username of the user.
:param hashed_password: (str) the hashed password of the user.
"""
# The UUID for the user is the username
super().__init__(uuid=username)

self.username = username
self.hashed_password = hashed_password

# always create a new random secret key
self.secret_key = secrets.token_hex(32)

@classmethod
def from_json(cls, json_data: str) -> Self:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Any reason you don't just type this as UserDAO?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I honestly don't know which is the canonical "right" way to do this, but you can use the more "obvious" type UserDAO if you import from __future__ import annotations

"""
Convert from a JSON string into an instance of User.

:param json_data: the user's data as a JSON string.
:return: (Self) and instancxe of UserDAO populated with the JSON data.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

instance?

"""
data = json.loads(json_data)
user = cls(data["username"], data["hashed_password"])
user.uuid = data["uuid"]
user.secret_key = data["secret_key"]
return user

def to_json(self) -> str:
"""
Convert the user data into a JSON string.

:return: (str) the user data as a JSON string.
"""
return json.dumps(
{
"uuid": self.uuid,
"username": self.username,
"hashed_password": self.hashed_password,
"secret_key": self.secret_key,
}
)
Loading