Axum and authentication

So far, our Axum application has been open to the public. In this chapter, we will add authentication to our application. We will start with Basic Authentication and then move on to JSON Web Tokens (JWT).

Basic Authentication

Basic Authentication is a simple authentication scheme built into the HTTP protocol. The client sends a username and password in the Authorization header, joined by a colon and Base64-encoded. If the credentials are incorrect, the server responds with a 401 Unauthorized status code. Because Base64 is an encoding and not encryption, Basic Authentication should only be used over HTTPS.
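To make the header format concrete, here is a std-only sketch of what a client such as curl does with `-u admin:password`: it joins the username and password with a colon and Base64-encodes the result. `base64_encode` and `basic_authorization` are hypothetical helpers written for illustration. They are not part of Axum or axum-extra, and a real client would use a Base64 crate instead.

```rust
/// Standard Base64 alphabet (RFC 4648), as used by Basic Authentication.
const ALPHABET: &[u8; 64] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";

/// Hypothetical helper: encode bytes as padded standard Base64.
fn base64_encode(input: &[u8]) -> String {
    let mut out = String::with_capacity((input.len() + 2) / 3 * 4);
    for chunk in input.chunks(3) {
        // Pack up to three bytes into one 24-bit group (missing bytes count as 0).
        let b0 = chunk[0] as u32;
        let b1 = *chunk.get(1).unwrap_or(&0) as u32;
        let b2 = *chunk.get(2).unwrap_or(&0) as u32;
        let group = (b0 << 16) | (b1 << 8) | b2;
        // Emit four 6-bit characters, using '=' padding for missing input bytes.
        out.push(ALPHABET[(group >> 18) as usize & 63] as char);
        out.push(ALPHABET[(group >> 12) as usize & 63] as char);
        if chunk.len() > 1 {
            out.push(ALPHABET[(group >> 6) as usize & 63] as char);
        } else {
            out.push('=');
        }
        if chunk.len() > 2 {
            out.push(ALPHABET[group as usize & 63] as char);
        } else {
            out.push('=');
        }
    }
    out
}

/// Hypothetical helper: build the value of the `Authorization` header.
fn basic_authorization(username: &str, password: &str) -> String {
    let credentials = format!("{username}:{password}");
    format!("Basic {}", base64_encode(credentials.as_bytes()))
}

fn main() {
    // Prints: Authorization: Basic YWRtaW46cGFzc3dvcmQ=
    println!("Authorization: {}", basic_authorization("admin", "password"));
}
```

Anyone who sees that header value can reverse the encoding just as easily, which is why the transport must be encrypted.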

The axum-extra crate (imported as axum_extra) provides an extractor for Basic Authentication. The TypedHeader<Authorization<Basic>> extractor reads the Authorization header from the request and parses it into a username and password. If the header is missing or malformed, the extractor rejects the request before our handler runs.
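Conceptually, the extractor reverses the client's steps: it checks for the `Basic` scheme, Base64-decodes the rest of the value and splits the result at the first colon. The std-only sketch below illustrates those steps. `parse_basic`, `base64_decode` and `sextet` are hypothetical names, and axum-extra's real implementation differs in its details and error handling.

```rust
/// Hypothetical helper: map a standard Base64 character to its 6-bit value.
fn sextet(c: u8) -> Option<u32> {
    match c {
        b'A'..=b'Z' => Some((c - b'A') as u32),
        b'a'..=b'z' => Some((c - b'a' + 26) as u32),
        b'0'..=b'9' => Some((c - b'0' + 52) as u32),
        b'+' => Some(62),
        b'/' => Some(63),
        _ => None,
    }
}

/// Hypothetical helper: decode padded standard Base64, rejecting malformed input.
fn base64_decode(input: &str) -> Option<Vec<u8>> {
    let bytes = input.as_bytes();
    if bytes.len() % 4 != 0 {
        return None;
    }
    let mut out = Vec::with_capacity(bytes.len() / 4 * 3);
    for (i, chunk) in bytes.chunks(4).enumerate() {
        let is_last = i == bytes.len() / 4 - 1;
        // Count trailing '=' padding; only the final group may contain it.
        let pad = chunk.iter().rev().take_while(|&&c| c == b'=').count();
        if pad > 2 || (pad > 0 && !is_last) {
            return None;
        }
        // Rebuild the 24-bit group from the non-padding characters.
        let mut group = 0u32;
        for &c in &chunk[..4 - pad] {
            group = (group << 6) | sextet(c)?;
        }
        group <<= 6 * pad as u32;
        let decoded = [(group >> 16) as u8, (group >> 8) as u8, group as u8];
        out.extend_from_slice(&decoded[..3 - pad]);
    }
    Some(out)
}

/// Hypothetical helper: split an `Authorization: Basic ...` value into
/// (username, password), returning `None` for any other scheme or bad input.
fn parse_basic(header_value: &str) -> Option<(String, String)> {
    // The scheme name is case-insensitive and followed by a space.
    let scheme = header_value.get(..6)?;
    if !scheme.eq_ignore_ascii_case("basic ") {
        return None;
    }
    let decoded = String::from_utf8(base64_decode(header_value[6..].trim())?).ok()?;
    // Split at the first colon: a username cannot contain ':', a password can.
    let (username, password) = decoded.split_once(':')?;
    Some((username.to_string(), password.to_string()))
}

fn main() {
    let parsed = parse_basic("Basic YWRtaW46cGFzc3dvcmQ=");
    println!("{parsed:?}"); // Some(("admin", "password"))
}
```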

Let's explore Basic Authentication with an example. We will enhance our simple web server and add a reset-visits endpoint that requires Basic Authentication.

First, add the axum-extra crate to your Cargo.toml. TypedHeader sits behind its typed-header feature, so enable that feature as well:

[dependencies]
serde = { version = "1.0.197", features = ["derive"] }
tokio = { version = "1", features = ["full"] }
axum = "0.7"
axum-extra = { version = "0.9", features = ["typed-header"] }
serde_json = "1"

We'll also include serde_json for JSON serialization and deserialization.

Next, let's add the reset-visits endpoint to our web server.

main.rs

use std::sync::Arc;

use axum::{
    extract::{Path, State},
    http::StatusCode,
    response::IntoResponse,
    routing::{delete, get},
    Json, Router,
};
use axum_extra::{
    headers::{authorization::Basic, Authorization},
    TypedHeader,
};
use serde_json::json;

use crate::handler::{GreetingHandler, WebHandler};
use crate::model::Greeting;

mod handler;
mod model;

type AppState = Arc<dyn GreetingHandler>;

#[tokio::main]
async fn main() {
    // Create a shared state for our application. We use an Arc so that we clone the pointer to the state and
    // not the state itself.
    let app_state: AppState = Arc::new(WebHandler::default());

    // Set up the routes and attach the shared state.
    let app = Router::new()
        .route("/hello/:visitor", get(greet_visitor))
        .route("/bye", delete(say_goodbye))
        .route("/reset-visits", delete(reset_visits))
        .with_state(app_state);

    // start the server on port 3000
    let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap();
    axum::serve(listener, app).await.unwrap();
}

/// Extract the `visitor` path parameter and use it to greet the visitor.
/// We also use the `State` extractor to access the shared `Handler` and call the `greet` method.
/// We use `Json` to automatically serialize the `Greeting` struct to JSON.
async fn greet_visitor(
    State(handler): State<AppState>,
    Path(visitor): Path<String>,
) -> Json<Greeting> {
    Json(handler.greet(visitor))
}

/// Say goodbye to the visitor.
async fn say_goodbye(State(handler): State<AppState>) -> String {
    handler.say_goodbye()
}

/// Reset the number of visits.
async fn reset_visits(
    TypedHeader(Authorization(creds)): TypedHeader<Authorization<Basic>>,
    State(handler): State<AppState>,
) -> impl IntoResponse {
    if creds.username() != "admin" || creds.password() != "password" {
        return (
            StatusCode::UNAUTHORIZED,
            Json(json!({"error": "Unauthorized"})),
        );
    };

    handler.reset_visits();
    (StatusCode::OK, Json(json!({"ok": "Visits reset"})))
}

And finally, let's add the reset_visits method to our GreetingHandler trait and implementation.

handler.rs

use std::sync::atomic::{AtomicU16, Ordering::Relaxed};

use crate::model::Greeting;

/// A trait for handling greetings.
pub trait GreetingHandler: Send + Sync {
    fn greet(&self, visitor: String) -> Greeting;
    fn say_goodbye(&self) -> String;
    fn reset_visits(&self);
}

/// A greeting handler implementation for our web application.
pub struct WebHandler {
    number_of_visits: AtomicU16,
}

impl GreetingHandler for WebHandler {
    /// Greet the visitor and increment the number of visits.
    fn greet(&self, visitor: String) -> Greeting {
        let visits = self.number_of_visits.fetch_add(1, Relaxed);
        Greeting::new("Hello", visitor, visits)
    }

    /// Say goodbye to the visitor.
    fn say_goodbye(&self) -> String {
        "Goodbye".to_string()
    }

    /// Reset the number of visits.
    fn reset_visits(&self) {
        self.number_of_visits.store(0, Relaxed);
    }
}

impl Default for WebHandler {
    fn default() -> Self {
        WebHandler {
            number_of_visits: AtomicU16::new(0),
        }
    }
}

OK, let's look at the reset_visits handler in a bit more detail.

We use the TypedHeader<Authorization<Basic>> extractor to extract the Authorization header from the request. A destructuring pattern in the parameter list, TypedHeader(Authorization(creds)), then unwraps both layers so that creds holds the credentials directly.
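Destructuring nested tuple structs in a function parameter works exactly like it does in a `let` binding. The sketch below uses hypothetical stand-in types that only mirror the shape of axum-extra's `TypedHeader`, `Authorization` and `Basic` (they are not the real types), so it runs without any dependencies.

```rust
// Hypothetical stand-ins that mirror the *shape* of axum-extra's types;
// they are not the real TypedHeader, Authorization, or Basic.
struct TypedHeader<T>(T);
struct Authorization<C>(C);
struct Basic {
    username: String,
    password: String,
}

impl Basic {
    fn username(&self) -> &str {
        &self.username
    }
    fn password(&self) -> &str {
        &self.password
    }
}

// The parameter pattern peels off both tuple-struct wrappers in one step,
// binding the inner credentials directly to `creds`.
fn describe(TypedHeader(Authorization(creds)): TypedHeader<Authorization<Basic>>) -> String {
    format!("user={} password_len={}", creds.username(), creds.password().len())
}

// The same thing written out with an explicit `let` binding.
fn describe_longhand(header: TypedHeader<Authorization<Basic>>) -> String {
    let TypedHeader(Authorization(creds)) = header;
    format!("user={} password_len={}", creds.username(), creds.password().len())
}

fn main() {
    let header = TypedHeader(Authorization(Basic {
        username: "admin".to_string(),
        password: "password".to_string(),
    }));
    println!("{}", describe(header)); // user=admin password_len=8
}
```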

We then check whether the username and password are correct. If they are not, we return a 401 Unauthorized status code. If they are, we call the reset_visits method on our GreetingHandler and return a 200 OK status code. The credentials are hard-coded to keep the example short; a real application would look them up in a user store.
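One refinement worth knowing about: comparing strings with `!=` may stop at the first differing byte, so response timing can in principle reveal how much of a guess was right. A common mitigation is a comparison whose running time does not depend on where the inputs differ. The std-only sketch below shows the idea with hypothetical helpers `constant_time_eq` and `credentials_ok`. The compiler is not guaranteed to preserve the constant-time property, so production code should rely on a vetted crate rather than this sketch.

```rust
/// Hypothetical helper: compare two byte strings without returning early on
/// the first mismatch. Inputs of different length are rejected immediately,
/// which reveals only the length.
fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    if a.len() != b.len() {
        return false;
    }
    // OR together the XOR of every byte pair; any difference leaves a bit set.
    let diff = a.iter().zip(b).fold(0u8, |acc, (x, y)| acc | (x ^ y));
    diff == 0
}

/// Hypothetical version of the check in `reset_visits`, using the same
/// hard-coded demo credentials.
fn credentials_ok(username: &str, password: &str) -> bool {
    // Use `&` rather than `&&` so both comparisons always run.
    constant_time_eq(username.as_bytes(), b"admin")
        & constant_time_eq(password.as_bytes(), b"password")
}

fn main() {
    println!("{}", credentials_ok("admin", "password")); // true
    println!("{}", credentials_ok("admin", "passw0rd")); // false
}
```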

The json! macro from the serde_json crate builds a serde_json::Value from JSON-like syntax. This lets us return small JSON bodies without defining a dedicated struct for each response.

Test the reset-visits endpoint with curl:

curl -X DELETE -u admin:password http://localhost:3000/reset-visits

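See what happens when you use the wrong credentials, or when you omit them altogether. The two cases are handled in different places. Wrong credentials reach our handler, which returns 401 Unauthorized. A missing Authorization header is rejected by the TypedHeader extractor (with 400 Bad Request in axum-extra 0.9) before reset_visits is ever called.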

JSON Web Tokens (JWT)

JSON Web Tokens (JWT) are a more common way to authenticate REST requests. JWT is an open standard (RFC 7519) that defines a compact, self-contained way to securely transmit information between parties as a JSON object. A JWT can be signed using a secret (with the HMAC algorithm) or a public/private key pair using RSA or ECDSA. The website jwt.io provides a good introduction to JWT.

We'll use the jsonwebtoken crate to work with JWTs. Add the jsonwebtoken crate to your Cargo.toml:

[dependencies]
serde = { version = "1.0.197", features = ["derive"] }
tokio = { version = "1", features = ["full"] }
axum = "0.7"
axum-extra = { version = "0.9", features = ["typed-header"] }
serde_json = "1"
jsonwebtoken = "9"

Now, let's add JWT authentication to our web server. We'll add a login endpoint that returns a JWT. We'll then update the reset-visits endpoint to require a valid JWT.
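The token travels in the same Authorization header as before, but with the `Bearer` scheme instead of `Basic` (`Authorization: Bearer <token>`), and it is sent as-is with no extra Base64 step. A JWT itself consists of three dot-separated segments: header, payload and signature. The std-only sketch below shows how such a header value could be taken apart. `parse_bearer` and `jwt_parts` are hypothetical helpers; in the server code, the `TypedHeader<Authorization<Bearer>>` extractor and the jsonwebtoken crate do this work.

```rust
/// Hypothetical helper: return the token from an `Authorization: Bearer ...`
/// header value, or `None` if the scheme differs or the token is empty.
fn parse_bearer(header_value: &str) -> Option<&str> {
    // Split off the scheme at the first space; the scheme is case-insensitive.
    let (scheme, token) = header_value.split_once(' ')?;
    if !scheme.eq_ignore_ascii_case("Bearer") {
        return None;
    }
    let token = token.trim();
    if token.is_empty() {
        None
    } else {
        Some(token)
    }
}

/// Hypothetical helper: split a JWT into its header, payload and signature.
fn jwt_parts(token: &str) -> Option<(&str, &str, &str)> {
    let mut parts = token.split('.');
    let header = parts.next()?;
    let payload = parts.next()?;
    let signature = parts.next()?;
    // Exactly three segments are expected.
    if parts.next().is_some() {
        return None;
    }
    Some((header, payload, signature))
}

fn main() {
    // Placeholder value with the usual three dot-separated segments.
    let header_value = "Bearer header.payload.signature";
    if let Some(token) = parse_bearer(header_value) {
        println!("{:?}", jwt_parts(token)); // Some(("header", "payload", "signature"))
    }
}
```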

main.rs

use std::sync::Arc;

use crate::{
    handler::{GreetingHandler, WebHandler},
    model::{Greeting, OurJwtPayload},
};

use axum::{
    extract::{Path, State},
    http::StatusCode,
    response::IntoResponse,
    routing::{delete, get, post},
    Json, Router,
};
use axum_extra::{
    headers::{
        authorization::{Basic, Bearer},
        Authorization,
    },
    TypedHeader,
};
use jsonwebtoken::{DecodingKey, Validation};
use serde_json::json;

mod handler;
mod model;

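// Demo only: a real application should load the signing key from configuration
// (for example, an environment variable) instead of hard-coding it.
const SECRET_SIGNING_KEY: &[u8] = b"keep_th1s_@_secret";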

type AppState = Arc<dyn GreetingHandler>;

#[tokio::main]
async fn main() {
    // Create a shared state for our application. We use an Arc so that we clone the pointer to the state and
    // not the state itself.
    let app_state: AppState = Arc::new(WebHandler::default());

    // Set up the routes and attach the shared state.
    let app = Router::new()
        .route("/hello/:visitor", get(greet_visitor))
        .route("/bye", delete(say_goodbye))
        .route("/login", post(login))
        .route("/reset-visits", delete(reset_visits))
        .with_state(app_state);

    // start the server on port 3000
    let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap();
    axum::serve(listener, app).await.unwrap();
}

/// Extract the `visitor` path parameter and use it to greet the visitor.
/// We also use the `State` extractor to access the shared `Handler` and call the `greet` method.
/// We use `Json` to automatically serialize the `Greeting` struct to JSON.
async fn greet_visitor(
    State(handler): State<AppState>,
    Path(visitor): Path<String>,
) -> Json<Greeting> {
    Json(handler.greet(visitor))
}

/// Say goodbye to the visitor.
async fn say_goodbye(State(handler): State<AppState>) -> String {
    handler.say_goodbye()
}

/// Log in with Basic Authentication and return a signed JWT.
async fn login(
    TypedHeader(Authorization(creds)): TypedHeader<Authorization<Basic>>,
) -> impl IntoResponse {
    if creds.username() != "admin" || creds.password() != "password" {
        return (
            StatusCode::UNAUTHORIZED,
            Json(json!({"error": "Unauthorized"})),
        );
    };

    let Ok(jwt) = jsonwebtoken::encode(
        &jsonwebtoken::Header::default(),
        &OurJwtPayload::new(creds.username().to_string()),
        &jsonwebtoken::EncodingKey::from_secret(SECRET_SIGNING_KEY),
    ) else {
        return (
            StatusCode::INTERNAL_SERVER_ERROR,
            Json(json!({"error": "Failed to generate token"})),
        );
    };

    (StatusCode::OK, Json(json!({"jwt": jwt})))
}

/// Reset the number of visits.
async fn reset_visits(
    TypedHeader(Authorization(bearer)): TypedHeader<Authorization<Bearer>>,
    State(handler): State<AppState>,
) -> impl IntoResponse {
    let token = bearer.token();
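    let decoding_key = DecodingKey::from_secret(SECRET_SIGNING_KEY);

    // decode verifies the signature and, with the default Validation (HS256),
    // also rejects tokens whose exp claim has passed.
    let Ok(jwt) =
        jsonwebtoken::decode::<OurJwtPayload>(token, &decoding_key, &Validation::default())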
    else {
        return (
            StatusCode::UNAUTHORIZED,
            Json(json!({"error": "Invalid token"})),
        );
    };

    let username = jwt.claims.sub;
    handler.reset_visits();

    (
        StatusCode::OK,
        Json(json!({"ok": format!("Visits reset by {username}")})),
    )
}

Update the model.rs file to include the OurJwtPayload struct. Its sub (subject) and exp (expiration time) fields are registered JWT claims:

use std::time::{Duration, SystemTime};

use serde::{Deserialize, Serialize};

#[derive(Serialize, Deserialize)]
pub struct Greeting {
    greeting: String,
    visitor: String,
    visits: u16,
}

impl Greeting {
    pub(crate) fn new(greeting: &str, visitor: String, visits: u16) -> Self {
        Greeting {
            greeting: greeting.to_string(),
            visitor,
            visits,
        }
    }
}

#[derive(Serialize, Deserialize)]
pub struct OurJwtPayload {
    pub sub: String,
    pub exp: usize,
}

impl OurJwtPayload {
    pub fn new(sub: String) -> Self {
        // Expire the token 60 minutes from now, as seconds since the Unix epoch.
        let exp = SystemTime::now()
            .checked_add(Duration::from_secs(60 * 60))
            .expect("valid timestamp")
            .duration_since(SystemTime::UNIX_EPOCH)
            .expect("valid duration")
            .as_secs() as usize;

        OurJwtPayload { sub, exp }
    }
}
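The exp claim is a plain number: seconds since the Unix epoch (1970-01-01 UTC). That is why new converts a SystemTime with duration_since(UNIX_EPOCH), and why checking expiry comes down to comparing two integers. The sketch below uses hypothetical helpers `expires_at` and `is_expired`. The `leeway` parameter is an assumption modelled on the clock-skew allowance that jsonwebtoken's `Validation` also applies (see its documentation for the default). In the server, `jsonwebtoken::decode` performs this check for you.

```rust
use std::time::{Duration, SystemTime, UNIX_EPOCH};

/// Hypothetical helper: compute an `exp` value `ttl` after `now_secs`.
fn expires_at(now_secs: u64, ttl: Duration) -> u64 {
    now_secs + ttl.as_secs()
}

/// Hypothetical helper: a token counts as expired once the current time has
/// moved past `exp` by more than the allowed clock skew `leeway` (seconds).
fn is_expired(exp: u64, now_secs: u64, leeway: u64) -> bool {
    now_secs.saturating_sub(leeway) > exp
}

fn main() {
    // Current Unix time in seconds, computed the same way as in OurJwtPayload::new.
    let now = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .expect("clock is after 1970")
        .as_secs();
    let exp = expires_at(now, Duration::from_secs(60 * 60));
    println!("exp = {exp}, expired now? {}", is_expired(exp, now, 60)); // false
    // Two hours later the same token would be rejected.
    println!("expired in 2h? {}", is_expired(exp, now + 2 * 60 * 60, 60)); // true
}
```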

We can now use curl to test the login and reset-visits endpoints. First, log in:

curl -X POST -u admin:password http://localhost:3000/login

You should get a response similar to:

{
  "jwt": "eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJzdWIiOiJhZG1pbiIsImV4cCI6MTcxMTYzMDc1OH0.O6bFdi080bS2OIRTcHD2VgeTGwk-r14mfqodsWikZg4"
}

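With this token, you can now test the reset-visits endpoint. Use the token from your own login response: the one below is only an example, and its exp claim is long past, so the server would reject it as expired.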

curl -X DELETE -H "Authorization: Bearer eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJzdWIiOiJhZG1pbiIsImV4cCI6MTcxMTYzMDc1OH0.O6bFdi080bS2OIRTcHD2VgeTGwk-r14mfqodsWikZg4" http://localhost:3000/reset-visits

Try to change the token or omit it altogether to see what happens.

You can also paste the token into the debugger at jwt.io to inspect its header and claims.

Security

Be aware that a JWT is not a secure place for sensitive information. A signed token is encoded (Base64url), not encrypted: anyone with the token can decode it and read its contents. Do not store sensitive information in the token.
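To see that a JWT is only encoded, the std-only sketch below decodes the payload segment of the example token from the curl command above without using any key. JWT segments use the URL-safe Base64 alphabet (`-` and `_` instead of `+` and `/`) with the `=` padding removed. `base64url_decode` is a hypothetical helper written for illustration.

```rust
/// Hypothetical helper: decode unpadded URL-safe Base64 (RFC 4648, section 5).
fn base64url_decode(input: &str) -> Option<Vec<u8>> {
    let mut out = Vec::new();
    let mut buffer: u32 = 0; // accumulated bits (only the low bits matter)
    let mut bits = 0; // number of not-yet-emitted bits in `buffer`
    for c in input.bytes() {
        let value = match c {
            b'A'..=b'Z' => c - b'A',
            b'a'..=b'z' => c - b'a' + 26,
            b'0'..=b'9' => c - b'0' + 52,
            b'-' => 62,
            b'_' => 63,
            _ => return None,
        };
        buffer = (buffer << 6) | u32::from(value);
        bits += 6;
        // Emit a byte whenever at least eight bits are available.
        if bits >= 8 {
            bits -= 8;
            out.push((buffer >> bits) as u8);
        }
    }
    // Any leftover bits (fewer than eight) are padding and are ignored.
    Some(out)
}

fn main() {
    // The middle (payload) segment of the example token used with curl.
    let payload = "eyJzdWIiOiJhZG1pbiIsImV4cCI6MTcxMTYzMDc1OH0";
    let bytes = base64url_decode(payload).expect("valid base64url");
    let json = String::from_utf8(bytes).expect("valid UTF-8");
    println!("{json}"); // {"sub":"admin","exp":1711630758}
}
```

Decoding the first segment the same way yields the header, {"typ":"JWT","alg":"HS256"}. Only the third segment, the signature, depends on the secret key.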

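What you can trust, once the token has been verified, is that its content has not been tampered with. The token is signed with a secret key: if the header or payload is changed, the signature no longer matches and the token is rejected as invalid. So always verify the token before using its claims (jsonwebtoken::decode does this in reset_visits), and keep the signing key itself secret.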

Exercises

This might be a good time to add some web endpoints to your exercise project. Complete the corresponding exercise in the Exercises section. Looking forward to seeing you next time!