1
use crate::config::ConfigError;
2
use cfg_if::cfg_if;
3
use sqlx::migrate::MigrateError;
4
use sqlx::pool::PoolConnection;
5
use sqlx::{PgPool, Postgres};
6
use thiserror::Error;
7
cfg_if! {
    if #[cfg(test)] {
        // Server's own test builds only need `Cell` for the thread-local
        // `DB_POOL` slot; they never read `DATABASE_URL`.
        use std::cell::Cell;
    } else {
        // `Cell` backs the thread-local `DB_POOL` override, which only exists
        // when the `test-utils` feature is enabled.
        #[cfg(feature = "test-utils")]
        use std::cell::Cell;
        use sqlx::postgres::PgPoolOptions;
        use std::env::var;
        use std::sync::LazyLock;
        use std::time::Duration;
        use tokio::sync::OnceCell;

        /// Admin connection string, read from the `DATABASE_URL` environment
        /// variable the first time it is dereferenced.
        ///
        /// # Panics
        /// The initializer panics (with a translated message) when the
        /// variable cannot be read.
        static DB_URL: LazyLock<String> = LazyLock::new(|| match var("DATABASE_URL") {
            Ok(url) => url,
            Err(_) => panic!("{}", String::from(t!("DATABASE_URL is not provided"))),
        });
    }
}
35

            
36
#[derive(Debug, Error)]
37
pub enum DBError {
38
    #[error("Database error: {0}")]
39
    Sqlx(#[from] sqlx::Error),
40
    #[error("DB migration error: {0}")]
41
    Migration(#[from] MigrateError),
42
    #[error("Configuration access error")]
43
    Config(#[from] ConfigError),
44
    #[error("DATABASE_URL is not provided")]
45
    MissingUrl,
46
    #[error("The database role lacks CREATEDB privilege")]
47
    NoCreateDb,
48
    #[error("Failed to generate an authentication keypair")]
49
    KeyGen,
50
}
51

            
52
pub async fn migrate_db() -> Result<(), DBError> {
53
    Ok(sqlx::migrate!("../migrations")
54
        .run(get_pool().await?)
55
        .await?)
56
}
57

            
58
8399
pub async fn get_connection() -> Result<PoolConnection<Postgres>, DBError> {
59
8399
    Ok(get_pool().await?.acquire().await?)
60
8399
}
61

            
62
/// Runs a multi-statement SQL script against the admin pool via the simple
63
/// query protocol. Executed against the shared pool reference (not a
64
/// `&mut PgConnection`) so the future stays `Send`/spawnable — the `&mut`
65
/// executor form trips a higher-ranked-lifetime bound once `boot()` is spawned.
66
351
pub async fn execute_raw(sql: &str) -> Result<(), DBError> {
67
351
    sqlx::raw_sql(sql).execute(get_pool().await?).await?;
68
324
    Ok(())
69
351
}
70

            
71
cfg_if! {
    if #[cfg(test)] {
        // Server's own tests rely on `local_db_sqlx_test` (supp_macro) injecting
        // a sqlx::test pool into DB_POOL before commands run. Forgetting that
        // setup is a programming error — panic loudly rather than silently
        // sharing a real DATABASE_URL pool across tests, which would break the
        // per-test isolation sqlx::test guarantees.
        thread_local!(pub static DB_POOL: Cell<*const PgPool> = const {
            Cell::new(std::ptr::null())
        });

        /// True when the current thread has installed a test pool. Server-side
        /// callers consult this to decide between the test pool and the
        /// per-user production pool.
        #[must_use]
        pub fn test_pool_is_set() -> bool {
            DB_POOL.with(|slot| !slot.get().is_null())
        }

        /// Returns the pool installed in this thread's `DB_POOL` slot.
        ///
        /// # Panics
        /// If no pool has been installed on the current thread.
        async fn get_pool() -> Result<&'static PgPool, DBError> {
            let installed: Option<&'static PgPool> = DB_POOL.with(|slot| {
                // SAFETY: a non-null pointer in `DB_POOL` is trusted to reference
                // a live `PgPool` that outlives every borrow handed out here.
                // NOTE(review): the `'static` lifetime is not compiler-enforced;
                // confirm the installer (`local_db_sqlx_test`) upholds it.
                unsafe { slot.get().as_ref() }
            });
            Ok(installed.expect("DB_POOL must be set; see local_db_sqlx_test macro"))
        }
    } else if #[cfg(feature = "test-utils")] {
        // Downstream consumers (the workspace `tests-integration` crate, and
        // any web/sshd test that wants to drive a `server::command::*` flow
        // against an isolated DB) install a sqlx::test pool via DB_POOL. When
        // it is not installed we fall back to the production `DATABASE_URL`
        // pool — that's the path web's existing tests take when compiled with
        // `--all-features`, since they never set DB_POOL and instead let the
        // production pool resolve via JWT'd handlers.
        thread_local!(pub static DB_POOL: Cell<*const PgPool> = const {
            Cell::new(std::ptr::null())
        });

        /// True when the current thread has installed a test pool. Used by
        /// `User::get_connection` to choose between the test override and the
        /// per-user production pool.
        #[must_use]
        pub fn test_pool_is_set() -> bool {
            DB_POOL.with(|slot| !slot.get().is_null())
        }

        /// Lazily-connected `DATABASE_URL` pool used when no test pool is installed.
        static FALLBACK_POOL: OnceCell<PgPool> = OnceCell::const_new();

        /// Returns the thread's installed test pool if there is one, otherwise
        /// the shared fallback pool (connecting to `DATABASE_URL` on first use).
        ///
        /// # Errors
        /// [`DBError::Sqlx`] if the fallback pool cannot connect. The cell stays
        /// uninitialized in that case, so a later call tries again.
        ///
        /// # Panics
        /// If the fallback is needed and `DATABASE_URL` cannot be read (raised
        /// by the `DB_URL` initializer).
        async fn get_pool() -> Result<&'static PgPool, DBError> {
            // Resolve the raw pointer to a reference inside the closure so no
            // `*const` local is in scope across the `.await` below.
            let installed: Option<&'static PgPool> = DB_POOL.with(|slot| {
                // SAFETY: a non-null pointer in `DB_POOL` is trusted to reference
                // a live `PgPool` that outlives every borrow handed out here.
                // NOTE(review): the `'static` lifetime is not compiler-enforced;
                // confirm the code that installs the pool upholds it.
                unsafe { slot.get().as_ref() }
            });
            if let Some(pool) = installed {
                return Ok(pool);
            }

            // `get_or_try_init` surfaces a connect failure as `DBError::Sqlx`
            // instead of panicking inside the initializer.
            let pool = FALLBACK_POOL
                .get_or_try_init(|| async {
                    log::debug!("Fallback pool initialization");
                    PgPoolOptions::new()
                        .max_connections(10)
                        .acquire_timeout(Duration::from_secs(10))
                        .connect(&DB_URL)
                        .await
                })
                .await?;
            Ok(pool)
        }
    } else {
        // Production: lazy-init from DATABASE_URL.
        static DB_POOL: OnceCell<PgPool> = OnceCell::const_new();

        /// Returns the process-wide pool, connecting to `DATABASE_URL` on first use.
        ///
        /// # Errors
        /// [`DBError::Sqlx`] if the pool cannot connect. The cell stays
        /// uninitialized in that case, so a later call tries again.
        ///
        /// # Panics
        /// If `DATABASE_URL` cannot be read (raised by the `DB_URL` initializer).
        async fn get_pool() -> Result<&'static PgPool, DBError> {
            // `get_or_try_init` surfaces a connect failure as `DBError::Sqlx`
            // instead of panicking inside the initializer.
            let pool = DB_POOL
                .get_or_try_init(|| async {
                    log::debug!("Pool initialization");
                    PgPoolOptions::new()
                        .max_connections(10)
                        .acquire_timeout(Duration::from_secs(10))
                        .connect(&DB_URL)
                        .await
                })
                .await?;
            Ok(pool)
        }
    }
}
149

            
150
// Provisioning helpers. The per-user DSN is derived from the admin
// `DATABASE_URL`. Fully-qualified paths keep these independent of the per-arm
// imports in the `cfg_if!` above so they compile under every cfg. Server's own
// `cfg(test)` runs use isolated sqlx::test pools and never provision, but the
// pure helpers stay unit-testable there.
157

            
158
/// The admin `DATABASE_URL` the server was configured with.
159
///
160
/// # Errors
161
/// [`DBError::MissingUrl`] if the environment variable is unset.
162
351
pub fn admin_database_url() -> Result<String, DBError> {
163
351
    std::env::var("DATABASE_URL").map_err(|_| DBError::MissingUrl)
164
351
}
165

            
166
/// Runs the full migration set against the database at `url`. Used by
167
/// provisioning to bring a freshly-created per-user database up to the
168
/// (DDL-only) schema.
169
///
170
/// # Errors
171
/// [`DBError::Sqlx`] on connect failure, [`DBError::Migration`] if a migration
172
/// fails.
173
243
pub async fn migrate_url(url: &str) -> Result<(), DBError> {
174
243
    let pool = sqlx::postgres::PgPoolOptions::new()
175
243
        .max_connections(1)
176
243
        .acquire_timeout(std::time::Duration::from_secs(10))
177
243
        .connect(url)
178
243
        .await?;
179
243
    let result = sqlx::migrate!("../migrations").run(&pool).await;
180
    // Close the pool on BOTH paths before returning. A still-open connection to
181
    // the per-user DB would otherwise block a compensating `DROP DATABASE` on
182
    // the migration-failure path.
183
243
    pool.close().await;
184
243
    result.map_err(DBError::from)
185
243
}
186

            
187
#[cfg(test)]
mod db_tests {
    use sqlx::PgPool;
    use tokio::sync::OnceCell;

    /// Context for keeping the environment intact: guards the one-time setup
    /// so it runs at most once per test process.
    static CONTEXT: OnceCell<()> = OnceCell::const_new();

    /// Performs one-time test setup. With the `testlog` feature enabled this
    /// installs a trace-level `env_logger`; otherwise it does nothing.
    async fn setup() {
        let init = || async {
            // `try_init` fails if a logger is already installed; that is fine.
            #[cfg(feature = "testlog")]
            let _ = env_logger::builder()
                .is_test(true)
                .filter_level(log::LevelFilter::Trace)
                .try_init();
        };
        CONTEXT.get_or_init(init).await;
    }

    /// The migration set is DDL-only: it creates the `config` table but no
    /// longer seeds it (seeding moved to `bootstrap::seed`), so a freshly
    /// migrated database has the table present and empty.
    #[sqlx::test(migrations = "../migrations")]
    async fn migrations_create_schema_without_seed(pool: PgPool) -> sqlx::Result<()> {
        setup().await;

        let mut connection = pool.acquire().await?;
        let seeded_rows: i64 = sqlx::query_scalar("SELECT count(*) FROM config")
            .fetch_one(&mut *connection)
            .await?;

        assert_eq!(seeded_rows, 0, "migration set must not seed config rows");
        Ok(())
    }
}