Stateful Web with Axum

So far, we've seen how to build a simple web app with axum. Let's add some state to it and build a global counter of the number of requests to our REST API. We'll use an AtomicU16 to keep track of the number of visits. AtomicU16 is a 16-bit unsigned integer that can be updated atomically, so concurrent request handlers can increment it without a lock and without losing updates.
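A minimal, std-only sketch of how `fetch_add` behaves on an AtomicU16, outside of axum. `record_visit` is a hypothetical helper introduced only for illustration. Two details matter for the handler later on. First, `fetch_add` returns the value *before* the increment, so a counter that starts at 1 reports 1 for the first visit. Second, on overflow it wraps around to 0 rather than panicking, so a u16 counter silently restarts after 65,535 visits.

```rust
use std::sync::atomic::{AtomicU16, Ordering};

/// Hypothetical helper: record one visit and return the number to report,
/// mirroring the tutorial's `fetch_add(1, Relaxed)` call.
fn record_visit(counter: &AtomicU16) -> u16 {
    // fetch_add returns the value *before* the addition.
    counter.fetch_add(1, Ordering::Relaxed)
}

fn main() {
    // Start at 1, as the tutorial's AppState does.
    let visits = AtomicU16::new(1);
    let first = record_visit(&visits);
    let second = record_visit(&visits);
    println!(
        "first = {first}, second = {second}, stored = {}",
        visits.load(Ordering::Relaxed)
    );

    // Atomic fetch_add wraps around on overflow instead of panicking.
    let near_max = AtomicU16::new(u16::MAX);
    let last = record_visit(&near_max);
    println!("last = {last}, stored = {}", near_max.load(Ordering::Relaxed));
}
```

Running it prints `first = 1, second = 2, stored = 3` and then `last = 65535, stored = 0`. If wrapping is unacceptable, a wider type such as AtomicU64 pushes the limit far out of reach.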

By wrapping the state in an Arc, we can share it between the different request handlers. Arc is an atomically reference-counted pointer: cloning it copies only the pointer, so every clone, on whichever thread it is used, refers to the same state. The State extractor then gives each request handler access to that shared state.
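A small std-only sketch of what the Arc buys us. The AppState struct mirrors the tutorial's; `assert_send_sync` and `new_state` are hypothetical helpers for illustration. Cloning the Arc (roughly what happens when a handler receives its `State`) creates a second handle to the same allocation rather than a second AppState, and the compile-time check shows why the state may be shared across threads at all.

```rust
use std::sync::atomic::{AtomicU16, Ordering};
use std::sync::Arc;

// Same shape as the tutorial's AppState.
struct AppState {
    number_of_visits: AtomicU16,
}

// Compiles only if `T` may be moved to, and shared between, threads.
fn assert_send_sync<T: Send + Sync>() {}

/// Hypothetical helper: create the shared state the way `main` does.
fn new_state() -> Arc<AppState> {
    Arc::new(AppState {
        number_of_visits: AtomicU16::new(1),
    })
}

fn main() {
    // AtomicU16 is Sync, so AppState is Sync and Arc<AppState> is Send + Sync.
    // With a plain `Cell<u16>` field instead, this line would not compile.
    assert_send_sync::<Arc<AppState>>();

    let state = new_state();

    // Cloning copies the pointer and bumps the reference count.
    let handle = Arc::clone(&state);
    assert!(Arc::ptr_eq(&state, &handle));
    println!("strong count with two handles: {}", Arc::strong_count(&state));

    // An update through one handle is visible through the other.
    handle.number_of_visits.fetch_add(1, Ordering::Relaxed);
    println!(
        "visits seen via the original: {}",
        state.number_of_visits.load(Ordering::Relaxed)
    );

    // Dropping a clone only decrements the count; the state lives on.
    drop(handle);
    println!("strong count after drop: {}", Arc::strong_count(&state));
}
```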

Every request therefore points to the same AppState instance, and the AtomicU16 inside it counts visits across all requests.
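To check that one shared AppState gives a consistent count under concurrency, here is a std-only simulation that uses plain OS threads in place of axum's request handling. `handle_request` and `simulate` are hypothetical helpers, not part of axum. It uses the same `Relaxed` ordering as the tutorial: that is enough here because `fetch_add` is an atomic read-modify-write regardless of ordering, and the ordering only matters when other memory has to be synchronized through the counter.

```rust
use std::sync::atomic::{AtomicU16, Ordering::Relaxed};
use std::sync::Arc;
use std::thread;

// Same shape as the tutorial's AppState.
struct AppState {
    number_of_visits: AtomicU16,
}

/// Hypothetical stand-in for the counter logic in `greet_visitor`:
/// claim the next visit number from the shared state.
fn handle_request(state: &AppState) -> u16 {
    state.number_of_visits.fetch_add(1, Relaxed)
}

/// Simulate `requests` simultaneous requests, one thread each,
/// all sharing a single AppState. Returns the visit numbers, sorted.
fn simulate(requests: u16) -> Vec<u16> {
    let state = Arc::new(AppState {
        number_of_visits: AtomicU16::new(1),
    });
    let workers: Vec<_> = (0..requests)
        .map(|_| {
            // Each "request" gets its own clone of the pointer.
            let state = Arc::clone(&state);
            thread::spawn(move || handle_request(&state))
        })
        .collect();
    let mut seen: Vec<u16> = workers
        .into_iter()
        .map(|w| w.join().expect("worker thread panicked"))
        .collect();
    seen.sort_unstable();
    seen
}

fn main() {
    let seen = simulate(100);
    // Every simulated request received a distinct number from 1 to 100:
    // no increment was lost and none was handed out twice.
    assert_eq!(seen, (1..=100).collect::<Vec<u16>>());
    println!("{} requests, visit numbers {}..={}", seen.len(), seen[0], seen[seen.len() - 1]);
}
```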

use axum::{
    extract::{Path, State},
    routing::{delete, get},
    Json, Router,
};
use serde::{Deserialize, Serialize};
use std::sync::{atomic::AtomicU16, atomic::Ordering::Relaxed, Arc};

#[derive(Serialize, Deserialize)]
struct Greeting {
    greeting: String,
    visitor: String,
    visits: u16,
}

struct AppState {
    number_of_visits: AtomicU16,
}

impl Greeting {
    fn new(greeting: &str, visitor: String, visits: u16) -> Self {
        Greeting {
            greeting: greeting.to_string(),
            visitor,
            visits,
        }
    }
}

#[tokio::main]
async fn main() {
    // Create the shared state for our application. Wrapping it in an Arc means that cloning it
    // copies only the pointer, not the state itself. The counter starts at 1 because fetch_add
    // returns the value *before* incrementing, so the first visitor sees `visits: 1`.
    let app_state = Arc::new(AppState {
        number_of_visits: AtomicU16::new(1),
    });

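    // set up the routes (GET /hello/:visitor and DELETE /bye) and attach the shared state.
    // `:visitor` is the axum 0.7 path syntax; axum 0.8 and later expect "/hello/{visitor}".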
    let app = Router::new()
        .route("/hello/:visitor", get(greet_visitor))
        .route("/bye", delete(say_goodbye))
        .with_state(app_state);

    // start the server on port 3000
    let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap();
    axum::serve(listener, app).await.unwrap();
}

/// Extract the `visitor` path parameter and use it to greet the visitor.
/// We also use the `State` extractor to access the shared `AppState` and increment the number of visits.
/// We use `Json` to automatically serialize the `Greeting` struct to JSON.
async fn greet_visitor(
    State(app_state): State<Arc<AppState>>,
    Path(visitor): Path<String>,
) -> Json<Greeting> {
    let visits = app_state
        .number_of_visits
        .fetch_add(1, Relaxed);
    Json(Greeting::new("Hello", visitor, visits))
}

/// Say goodbye to the visitor.
async fn say_goodbye() -> String {
    "Goodbye".to_string()
}

You can test the web server by running the following command:

$ curl http://127.0.0.1:3000/hello/world

Result:

{
  "greeting": "Hello",
  "visitor": "world",
  "visits": 1
}

Some statistics using Apache Bench

$ ab -n 1000 -c 1 http://127.0.0.1:3000/hello/world

Concurrency Level:      1
Time taken for tests:   0.112 seconds
Complete requests:      1000

$ ab -n 1000 -c 100 http://127.0.0.1:3000/hello/world

Concurrency Level:      100
Time taken for tests:   0.072 seconds
Complete requests:      1000

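Setting the absolute numbers aside, the run with 100 concurrent clients completes the same 1000 requests in less time than the sequential run (0.072 seconds versus 0.112 seconds). This indicates that our axum server handles requests concurrently rather than one at a time.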
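For a rough sense of scale, the two ab timings can be turned into throughput figures. This is plain arithmetic on the single run shown above, not a benchmark, and `throughput` is a hypothetical helper.

```rust
/// Requests per second for a run of `requests` that finished in `seconds`.
fn throughput(requests: u32, seconds: f64) -> f64 {
    f64::from(requests) / seconds
}

fn main() {
    let sequential = throughput(1000, 0.112); // ab -c 1
    let concurrent = throughput(1000, 0.072); // ab -c 100
    println!("-c 1:   {sequential:.0} req/s");
    println!("-c 100: {concurrent:.0} req/s");
    println!("ratio:  {:.2}x", concurrent / sequential);
}
```

With these figures the sequential run comes to roughly 8,900 requests per second and the concurrent run to roughly 13,900, about 1.56 times as many. Numbers this small are dominated by machine and network-stack noise, so only the direction of the difference is worth reading into.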