1use crate::db::{DBError, get_connection};
9use crate::user::User;
10use base64::{Engine as _, engine::general_purpose::STANDARD};
11use rsa::pkcs8::{EncodePrivateKey, EncodePublicKey, LineEnding};
12use rsa::{RsaPrivateKey, RsaPublicKey};
13use sqlx::types::Uuid;
14
15const KEY_BITS: usize = 2048;
19
/// A freshly generated RSA key pair, with each half stored as base64-encoded PEM text.
///
/// NOTE(review): `private_pem_b64` carries secret key material; avoid logging it or
/// adding a `Debug` impl that would print it.
pub struct KeyPair {
    /// Base64 of the PKCS#8 PEM (`BEGIN PRIVATE KEY`) private key.
    pub private_pem_b64: String,
    /// Base64 of the SubjectPublicKeyInfo PEM (`BEGIN PUBLIC KEY`) public key.
    pub public_pem_b64: String,
}
25
26pub async fn generate() -> Result<KeyPair, DBError> {
35 tokio::task::spawn_blocking(generate_blocking)
36 .await
37 .map_err(|_| DBError::KeyGen)?
38}
39
40fn generate_blocking() -> Result<KeyPair, DBError> {
41 let mut rng = rand::thread_rng();
42 let private = RsaPrivateKey::new(&mut rng, KEY_BITS).map_err(|_| DBError::KeyGen)?;
43 let public = RsaPublicKey::from(&private);
44
45 let private_pem = private
46 .to_pkcs8_pem(LineEnding::LF)
47 .map_err(|_| DBError::KeyGen)?;
48 let public_pem = public
49 .to_public_key_pem(LineEnding::LF)
50 .map_err(|_| DBError::KeyGen)?;
51
52 Ok(KeyPair {
53 private_pem_b64: STANDARD.encode(private_pem.as_bytes()),
54 public_pem_b64: STANDARD.encode(public_pem.as_bytes()),
55 })
56}
57
58pub async fn private_key_for(user_id: Uuid) -> Result<String, DBError> {
67 let user = User { id: user_id };
68 let mut conn = user.get_connection().await?;
69 let key: Option<String> = sqlx::query_scalar("SELECT private_key FROM user_auth_keys LIMIT 1")
70 .fetch_optional(&mut *conn)
71 .await?;
72 key.ok_or(DBError::KeyGen)
73}
74
75pub async fn public_key_for(user_id: Uuid) -> Result<Option<String>, DBError> {
83 let mut conn = get_connection().await?;
84 let key: Option<String> = sqlx::query_scalar("SELECT jwt_public_key FROM users WHERE id = $1")
85 .bind(user_id)
86 .fetch_optional(&mut *conn)
87 .await?
88 .flatten();
89 Ok(key)
90}
91
#[cfg(test)]
mod tests {
    use super::*;
    use rsa::pkcs8::{DecodePrivateKey, DecodePublicKey};

    /// Strips the base64 layer applied by `generate`, yielding the PEM text.
    fn unwrap_b64(encoded: &str) -> String {
        let raw = STANDARD.decode(encoded).expect("base64");
        String::from_utf8(raw).expect("utf8")
    }

    #[tokio::test]
    async fn generate_produces_decodable_rsa_pem_pair() {
        let KeyPair {
            private_pem_b64,
            public_pem_b64,
        } = generate().await.expect("keygen");

        let private_pem = unwrap_b64(&private_pem_b64);
        let public_pem = unwrap_b64(&public_pem_b64);
        assert!(private_pem.contains("BEGIN PRIVATE KEY"));
        assert!(public_pem.contains("BEGIN PUBLIC KEY"));

        // Both halves must parse. The public half must also equal the key
        // derived from the private half.
        let decoded_private =
            RsaPrivateKey::from_pkcs8_pem(&private_pem).expect("private parses");
        let decoded_public =
            RsaPublicKey::from_public_key_pem(&public_pem).expect("public parses");
        assert_eq!(RsaPublicKey::from(&decoded_private), decoded_public);
    }

    #[tokio::test]
    async fn generate_yields_distinct_keys() {
        let first = generate().await.expect("keygen a");
        let second = generate().await.expect("keygen b");

        assert_ne!(first.private_pem_b64, second.private_pem_b64);
        assert_ne!(first.public_pem_b64, second.public_pem_b64);
    }
}