Skip to main content

server/
auth_keys.rs

//! Per-user JWT signing keypairs.
//!
//! Each user signs their own tokens with an RSA keypair: the private key lives
//! only in their per-user database (`user_auth_keys`), and the public key lives
//! in the `jwt_public_key` column of the global `users` table. Both keys are
//! stored as base64-encoded PEM — the same wire form the web layer's token
//! encode/decode already expects.
7
8use crate::db::{DBError, get_connection};
9use crate::user::User;
10use base64::{Engine as _, engine::general_purpose::STANDARD};
11use rsa::pkcs8::{EncodePrivateKey, EncodePublicKey, LineEnding};
12use rsa::{RsaPrivateKey, RsaPublicKey};
13use sqlx::types::Uuid;
14
/// Bit length of the RSA modulus used for every generated signing key.
///
/// 2048 bits is the conventional minimum for RS256 and is far cheaper to
/// generate than 4096. That matters here: keys are created while provisioning
/// users, so a more expensive key size would hand callers a cheap way to burn
/// server CPU on the request path.
const KEY_BITS: usize = 2_048;
19
/// A freshly generated keypair, both halves base64-encoded PEM.
///
/// `Debug` is implemented by hand rather than derived, so that formatting a
/// `KeyPair` with `{:?}` (in a log line, a panic message, an `assert!`
/// failure) never prints the private key.
pub struct KeyPair {
    /// Base64 of the PEM-encoded private key. Secret: it should only ever be
    /// persisted to the owning user's per-user database.
    pub private_pem_b64: String,
    /// Base64 of the PEM-encoded public key; used to verify the user's tokens.
    pub public_pem_b64: String,
}

impl std::fmt::Debug for KeyPair {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("KeyPair")
            // Never echo secret key material, even in diagnostics.
            .field("private_pem_b64", &"<redacted>")
            .field("public_pem_b64", &self.public_pem_b64)
            .finish()
    }
}
25
26/// Generates an RS256 keypair off the async runtime.
27///
28/// RSA keygen is CPU-bound (tens to hundreds of ms); running it on a
29/// `spawn_blocking` thread keeps it off the request executor so it can't stall
30/// the reactor or starve other tasks.
31///
32/// # Errors
33/// [`DBError::KeyGen`] if generation or PEM encoding fails.
34pub async fn generate() -> Result<KeyPair, DBError> {
35    tokio::task::spawn_blocking(generate_blocking)
36        .await
37        .map_err(|_| DBError::KeyGen)?
38}
39
40fn generate_blocking() -> Result<KeyPair, DBError> {
41    let mut rng = rand::thread_rng();
42    let private = RsaPrivateKey::new(&mut rng, KEY_BITS).map_err(|_| DBError::KeyGen)?;
43    let public = RsaPublicKey::from(&private);
44
45    let private_pem = private
46        .to_pkcs8_pem(LineEnding::LF)
47        .map_err(|_| DBError::KeyGen)?;
48    let public_pem = public
49        .to_public_key_pem(LineEnding::LF)
50        .map_err(|_| DBError::KeyGen)?;
51
52    Ok(KeyPair {
53        private_pem_b64: STANDARD.encode(private_pem.as_bytes()),
54        public_pem_b64: STANDARD.encode(public_pem.as_bytes()),
55    })
56}
57
58/// Fetches a user's base64-PEM private signing key from their per-user
59/// database (`user_auth_keys`, a single-row table). Used on the token MINT path
60/// after the user has been identified.
61///
62/// # Errors
63/// [`DBError::Sqlx`] on a DB error; [`DBError::KeyGen`] if the user has no key
64/// (a provisioning invariant violation — surfaced rather than silently signing
65/// with nothing).
66pub async fn private_key_for(user_id: Uuid) -> Result<String, DBError> {
67    let user = User { id: user_id };
68    let mut conn = user.get_connection().await?;
69    let key: Option<String> = sqlx::query_scalar("SELECT private_key FROM user_auth_keys LIMIT 1")
70        .fetch_optional(&mut *conn)
71        .await?;
72    key.ok_or(DBError::KeyGen)
73}
74
75/// Fetches a user's base64-PEM public verification key from the global `users`
76/// directory. Used on the token VERIFY path, looked up by the (still-unverified)
77/// `sub` claim — so the right key is fetched before the per-user DB is reachable.
78/// Returns `None` if the user is unknown or has no key.
79///
80/// # Errors
81/// [`DBError::Sqlx`] on a DB error.
82pub async fn public_key_for(user_id: Uuid) -> Result<Option<String>, DBError> {
83    let mut conn = get_connection().await?;
84    let key: Option<String> = sqlx::query_scalar("SELECT jwt_public_key FROM users WHERE id = $1")
85        .bind(user_id)
86        .fetch_optional(&mut *conn)
87        .await?
88        .flatten();
89    Ok(key)
90}
91
#[cfg(test)]
mod tests {
    use super::*;
    use rsa::pkcs8::{DecodePrivateKey, DecodePublicKey};

    /// Reverses the storage encoding: base64 text back to UTF-8 PEM text.
    fn decode_pem(b64: &str) -> String {
        let raw = STANDARD.decode(b64).expect("base64");
        String::from_utf8(raw).expect("utf8")
    }

    #[tokio::test]
    async fn generate_produces_decodable_rsa_pem_pair() {
        let KeyPair {
            private_pem_b64,
            public_pem_b64,
        } = generate().await.expect("keygen");
        let (priv_pem, pub_pem) = (decode_pem(&private_pem_b64), decode_pem(&public_pem_b64));

        assert!(priv_pem.contains("BEGIN PRIVATE KEY"));
        assert!(pub_pem.contains("BEGIN PUBLIC KEY"));

        // The web token layer parses this same base64-PEM form when signing
        // and verifying, so both halves must round-trip into RSA keys, and
        // the public half must be the one derived from the private half.
        let parsed_priv = RsaPrivateKey::from_pkcs8_pem(&priv_pem).expect("private parses");
        let parsed_pub = RsaPublicKey::from_public_key_pem(&pub_pem).expect("public parses");
        assert_eq!(parsed_pub, RsaPublicKey::from(&parsed_priv));
    }

    #[tokio::test]
    async fn generate_yields_distinct_keys() {
        let (first, second) = tokio::join!(generate(), generate());
        let first = first.expect("keygen a");
        let second = second.expect("keygen b");
        assert_ne!(first.private_pem_b64, second.private_pem_b64);
        assert_ne!(first.public_pem_b64, second.public_pem_b64);
    }
}