Skip to main content

server/
db.rs

1use crate::config::ConfigError;
2use cfg_if::cfg_if;
3use sqlx::migrate::MigrateError;
4use sqlx::pool::PoolConnection;
5use sqlx::{PgPool, Postgres};
6use thiserror::Error;
7cfg_if! {
8    if #[cfg(test)] {
9    use std::cell::Cell;
10    } else if #[cfg(feature = "test-utils")] {
11    use std::cell::Cell;
12    use sqlx::postgres::PgPoolOptions;
13    use std::env::var;
14    use std::sync::LazyLock;
15    use std::time::Duration;
16    use tokio::sync::OnceCell;
17
18    static DB_URL: LazyLock<String> = LazyLock::new(|| {
19        var("DATABASE_URL")
20        .unwrap_or_else(|_| panic!("{}", String::from(t!("DATABASE_URL is not provided"))))
21    });
22    } else {
23    use sqlx::postgres::PgPoolOptions;
24    use std::env::var;
25    use std::sync::LazyLock;
26    use std::time::Duration;
27    use tokio::sync::OnceCell;
28
29    static DB_URL: LazyLock<String> = LazyLock::new(|| {
30        var("DATABASE_URL")
31        .unwrap_or_else(|_| panic!("{}", String::from(t!("DATABASE_URL is not provided"))))
32    });
33    }
34}
35
36#[derive(Debug, Error)]
37pub enum DBError {
38    #[error("Database error: {0}")]
39    Sqlx(#[from] sqlx::Error),
40    #[error("DB migration error: {0}")]
41    Migration(#[from] MigrateError),
42    #[error("Configuration access error")]
43    Config(#[from] ConfigError),
44    #[error("DATABASE_URL is not provided")]
45    MissingUrl,
46    #[error("The database role lacks CREATEDB privilege")]
47    NoCreateDb,
48    #[error("Failed to generate an authentication keypair")]
49    KeyGen,
50}
51
52pub async fn migrate_db() -> Result<(), DBError> {
53    Ok(sqlx::migrate!("../migrations")
54        .run(get_pool().await?)
55        .await?)
56}
57
58pub async fn get_connection() -> Result<PoolConnection<Postgres>, DBError> {
59    Ok(get_pool().await?.acquire().await?)
60}
61
62/// Runs a multi-statement SQL script against the admin pool via the simple
63/// query protocol. Executed against the shared pool reference (not a
64/// `&mut PgConnection`) so the future stays `Send`/spawnable — the `&mut`
65/// executor form trips a higher-ranked-lifetime bound once `boot()` is spawned.
66pub async fn execute_raw(sql: &str) -> Result<(), DBError> {
67    sqlx::raw_sql(sql).execute(get_pool().await?).await?;
68    Ok(())
69}
70
71cfg_if! {
72    if #[cfg(test)] {
73    // Server's own tests rely on `local_db_sqlx_test` (supp_macro) injecting
74    // a sqlx::test pool into DB_POOL before commands run. Forgetting that
75    // setup is a programming error — panic loudly rather than silently
76    // sharing a real DATABASE_URL pool across tests, which would break the
77    // per-test isolation sqlx::test guarantees.
78    thread_local!(pub static DB_POOL: Cell<*const PgPool> = const {
79        Cell::new(std::ptr::null())
80    });
81
82    /// True when the current thread has installed a test pool. Server-side
83    /// callers consult this to decide between the test pool and the
84    /// per-user production pool.
85    #[must_use]
86    pub fn test_pool_is_set() -> bool {
87        DB_POOL.with(|c| !c.get().is_null())
88    }
89
90    async fn get_pool() -> Result<&'static PgPool, DBError> {
91        let p = DB_POOL.with(|c| c.get());
92        assert!(!p.is_null(), "DB_POOL must be set; see local_db_sqlx_test macro");
93        unsafe { Ok(&*p) }
94    }
95
96    } else if #[cfg(feature = "test-utils")] {
97    // Downstream consumers (the workspace `tests-integration` crate, and
98    // any web/sshd test that wants to drive a `server::command::*` flow
99    // against an isolated DB) install a sqlx::test pool via DB_POOL. When
100    // it is not installed we fall back to the production `DATABASE_URL`
101    // pool — that's the path web's existing tests take when compiled with
102    // `--all-features`, since they never set DB_POOL and instead let the
103    // production pool resolve via JWT'd handlers.
104    thread_local!(pub static DB_POOL: Cell<*const PgPool> = const {
105        Cell::new(std::ptr::null())
106    });
107
108    /// True when the current thread has installed a test pool. Used by
109    /// `User::get_connection` to choose between the test override and the
110    /// per-user production pool.
111    #[must_use]
112    pub fn test_pool_is_set() -> bool {
113        DB_POOL.with(|c| !c.get().is_null())
114    }
115
116    static FALLBACK_POOL: OnceCell<PgPool> = OnceCell::const_new();
117
118    async fn get_pool() -> Result<&'static PgPool, DBError> {
119        let p = DB_POOL.with(|c| c.get());
120        if !p.is_null() {
121            return unsafe { Ok(&*p) };
122        }
123        Ok(FALLBACK_POOL
124           .get_or_init(|| async {
125           log::debug!("Fallback pool initialization");
126           let options = PgPoolOptions::new()
127                       .max_connections(10)
128                       .acquire_timeout(Duration::from_secs(10));
129           options.connect(&DB_URL).await.unwrap()
130           }).await)
131    }
132
133    } else {
134    // Production: lazy-init from DATABASE_URL.
135    static DB_POOL: OnceCell<PgPool> = OnceCell::const_new();
136
137    async fn get_pool() -> Result<&'static PgPool, DBError> {
138            Ok(DB_POOL
139               .get_or_init(|| async {
140           log::debug!("Pool initialization");
141           let options = PgPoolOptions::new()
142                       .max_connections(10)
143                       .acquire_timeout(Duration::from_secs(10));
144           options.connect(&DB_URL).await.unwrap()
145           }).await)
146    }
147    }
148}
149
// Provisioning helpers. The per-user DSN is derived from the admin
// `DATABASE_URL`. These items are not cfg-gated. They use fully-qualified paths
// so they do not depend on the cfg-gated imports at the top of the file, and
// they compile under every cfg. Server's own `cfg(test)` runs use isolated
// sqlx::test pools and never provision, but the pure helpers remain
// unit-testable there.
158/// The admin `DATABASE_URL` the server was configured with.
159///
160/// # Errors
161/// [`DBError::MissingUrl`] if the environment variable is unset.
162pub fn admin_database_url() -> Result<String, DBError> {
163    std::env::var("DATABASE_URL").map_err(|_| DBError::MissingUrl)
164}
165
166/// Runs the full migration set against the database at `url`. Used by
167/// provisioning to bring a freshly-created per-user database up to the
168/// (DDL-only) schema.
169///
170/// # Errors
171/// [`DBError::Sqlx`] on connect failure, [`DBError::Migration`] if a migration
172/// fails.
173pub async fn migrate_url(url: &str) -> Result<(), DBError> {
174    let pool = sqlx::postgres::PgPoolOptions::new()
175        .max_connections(1)
176        .acquire_timeout(std::time::Duration::from_secs(10))
177        .connect(url)
178        .await?;
179    let result = sqlx::migrate!("../migrations").run(&pool).await;
180    // Close the pool on BOTH paths before returning. A still-open connection to
181    // the per-user DB would otherwise block a compensating `DROP DATABASE` on
182    // the migration-failure path.
183    pool.close().await;
184    result.map_err(DBError::from)
185}
186
#[cfg(test)]
mod db_tests {
    use sqlx::PgPool;
    use tokio::sync::OnceCell;

    /// Ensures the shared test setup runs only once per test process.
    static CONTEXT: OnceCell<()> = OnceCell::const_new();

    /// Installs a trace-level test logger when the `testlog` feature is on.
    /// A logger that is already installed is left in place.
    fn init_logging() {
        #[cfg(feature = "testlog")]
        {
            let _ = env_logger::builder()
                .is_test(true)
                .filter_level(log::LevelFilter::Trace)
                .try_init();
        }
    }

    /// One-time, idempotent test setup.
    async fn setup() {
        CONTEXT.get_or_init(|| async { init_logging() }).await;
    }

    #[sqlx::test(migrations = "../migrations")]
    async fn migrations_create_schema_without_seed(pool: PgPool) -> sqlx::Result<()> {
        setup().await;

        // The migrations are DDL-only. They create the `config` table but no
        // longer seed it (seeding lives in `bootstrap::seed`), so a freshly
        // migrated database has the table and it is empty.
        let mut connection = pool.acquire().await?;
        let row_count: i64 = sqlx::query_scalar("SELECT count(*) FROM config")
            .fetch_one(&mut *connection)
            .await?;
        assert_eq!(row_count, 0, "migration set must not seed config rows");

        Ok(())
    }
}