Skip to main content

server/
db.rs

1use crate::config::ConfigError;
2use cfg_if::cfg_if;
3use sqlx::migrate::MigrateError;
4use sqlx::pool::PoolConnection;
5use sqlx::{PgPool, Postgres};
6use thiserror::Error;
7cfg_if! {
8    if #[cfg(test)] {
9    use std::cell::Cell;
10    } else if #[cfg(not(test))] {
11    use sqlx::postgres::PgPoolOptions;
12    use std::env::var;
13    use std::sync::LazyLock;
14    use std::time::Duration;
15    use tokio::sync::OnceCell;
16
17    static DB_URL: LazyLock<String> = LazyLock::new(|| {
18        var("DATABASE_URL")
19        .unwrap_or_else(|_| panic!("{}", String::from(t!("DATABASE_URL is not provided"))))
20    });
21    }
22}
23
24#[derive(Debug, Error)]
25pub enum DBError {
26    #[error("Database error: {0}")]
27    Sqlx(#[from] sqlx::Error),
28    #[error("DB migration error: {0}")]
29    Migration(#[from] MigrateError),
30    #[error("Configuration access error")]
31    Config(#[from] ConfigError),
32}
33
34pub async fn migrate_db() -> Result<(), DBError> {
35    Ok(sqlx::migrate!("../migrations")
36        .run(get_pool().await?)
37        .await?)
38}
39
40pub async fn get_connection() -> Result<PoolConnection<Postgres>, DBError> {
41    Ok(get_pool().await?.acquire().await?)
42}
43
44cfg_if! {
45    if #[cfg(test)] {
46    // Just a place to store a pointer that can be mocked from inside tests
47    thread_local! (pub static DB_POOL: Cell<*const PgPool> = panic!("!"));
48
49    async fn get_pool() -> Result<&'static PgPool, DBError> {
50     unsafe {Ok(&*DB_POOL.get())}
51    }
52
53    } else if #[cfg(not(test))] {
54    // And a correct storage for the run-time
55    static DB_POOL: OnceCell<PgPool> = OnceCell::const_new();
56
57    async fn get_pool() -> Result<&'static PgPool, DBError> {
58            Ok(DB_POOL
59               .get_or_init(|| async {
60           log::debug!("Pool initialization");
61           let options = PgPoolOptions::new()
62                       .max_connections(10)
63                       .acquire_timeout(Duration::from_secs(10));
64           options.connect(&DB_URL).await.unwrap()
65           }).await)
66    }
67    }
68}
69
#[cfg(test)]
mod db_tests {
    use sqlx::PgPool;
    use tokio::sync::OnceCell;

    /// Guard that keeps the one-time environment setup from running twice.
    static CONTEXT: OnceCell<()> = OnceCell::const_new();

    /// One-time test environment initialisation.
    ///
    /// With the `testlog` feature enabled this installs an `env_logger` at
    /// `Trace` level, ignoring the error if a logger is already set. Without
    /// the feature it does nothing.
    async fn init_environment() {
        #[cfg(feature = "testlog")]
        let _ = env_logger::builder()
            .is_test(true)
            .filter_level(log::LevelFilter::Trace)
            .try_init();
    }

    /// Runs [`init_environment`] at most once per process.
    async fn setup() {
        CONTEXT.get_or_init(init_environment).await;
    }

    /// Checks that, with the `../migrations` migrations applied, the `config`
    /// row whose lower-cased `field` is `initialized` has contents `YES`.
    #[sqlx::test(migrations = "../migrations")]
    async fn migrations_test(pool: PgPool) -> sqlx::Result<()> {
        setup().await;

        let mut connection = pool.acquire().await?;
        let sql = "SELECT contents FROM config WHERE lower(field) = 'initialized'";
        let initialized: String = sqlx::query_scalar(sql)
            .fetch_one(&mut *connection)
            .await?;

        assert_eq!(initialized, "YES");
        Ok(())
    }
}