1
use crate::config::ConfigError;
2
use cfg_if::cfg_if;
3
use sqlx::migrate::MigrateError;
4
use sqlx::pool::PoolConnection;
5
use sqlx::{PgPool, Postgres};
6
use thiserror::Error;
7
cfg_if! {
    if #[cfg(test)] {
        use std::cell::Cell;
    } else {
        use std::env::var;
        use std::sync::LazyLock;
        use std::time::Duration;

        use sqlx::postgres::PgPoolOptions;
        use tokio::sync::OnceCell;

        /// Connection string for the runtime pool, read from the `DATABASE_URL`
        /// environment variable the first time it is dereferenced.
        ///
        /// # Panics
        /// Panics on first access if `DATABASE_URL` is unset or not valid Unicode
        /// (both are errors from `std::env::var`). The message goes through `t!`
        /// (presumably the project's i18n macro — confirm).
        static DB_URL: LazyLock<String> = LazyLock::new(|| match var("DATABASE_URL") {
            Ok(url) => url,
            Err(_) => panic!("{}", String::from(t!("DATABASE_URL is not provided"))),
        });
    }
}
23

            
24
#[derive(Debug, Error)]
25
pub enum DBError {
26
    #[error("Database error: {0}")]
27
    Sqlx(#[from] sqlx::Error),
28
    #[error("DB migration error: {0}")]
29
    Migration(#[from] MigrateError),
30
    #[error("Configuration access error")]
31
    Config(#[from] ConfigError),
32
}
33

            
34
pub async fn migrate_db() -> Result<(), DBError> {
35
    Ok(sqlx::migrate!("../migrations")
36
        .run(get_pool().await?)
37
        .await?)
38
}
39

            
40
1569
pub async fn get_connection() -> Result<PoolConnection<Postgres>, DBError> {
41
1106
    Ok(get_pool().await?.acquire().await?)
42
1106
}
43

            
44
cfg_if! {
    if #[cfg(test)] {
        // Test-only pool slot: a test stores a raw pointer to a pool it owns
        // (presumably the `PgPool` injected by `#[sqlx::test]` — confirm in callers)
        // so that code under test can reach it through `get_pool`.
        // It starts out null; `get_pool` refuses to dereference an unset slot.
        thread_local!(pub static DB_POOL: Cell<*const PgPool> = Cell::new(std::ptr::null()));

        /// Returns the pool registered in `DB_POOL` for the current thread.
        ///
        /// # Panics
        /// Panics if no pool pointer has been stored on this thread.
        async fn get_pool() -> Result<&'static PgPool, DBError> {
            let pool = DB_POOL.get();
            assert!(
                !pool.is_null(),
                "DB_POOL is not set: store a pool pointer with DB_POOL.set(...) before calling get_pool"
            );
            // SAFETY: `pool` is non-null (checked above). By the test contract it
            // points to a `PgPool` that the test keeps alive for as long as the
            // returned reference is used. The `'static` lifetime relies entirely
            // on that contract.
            // NOTE(review): this is unsound if a test drops its pool while a
            // borrow obtained here is still live.
            Ok(unsafe { &*pool })
        }
    } else {
        // Process-wide pool, created lazily on first use.
        static DB_POOL: OnceCell<PgPool> = OnceCell::const_new();

        /// Returns the shared pool, connecting to `DB_URL` on the first call.
        ///
        /// # Errors
        /// Returns [`DBError::Sqlx`] if the initial connection fails. The failure
        /// is not cached: the cell stays empty, so a later call retries.
        async fn get_pool() -> Result<&'static PgPool, DBError> {
            let pool = DB_POOL
                .get_or_try_init(|| async {
                    log::debug!("Pool initialization");
                    PgPoolOptions::new()
                        .max_connections(10)
                        .acquire_timeout(Duration::from_secs(10))
                        .connect(&DB_URL)
                        .await
                })
                .await?;
            Ok(pool)
        }
    }
}
69

            
70
#[cfg(test)]
mod db_tests {
    use sqlx::PgPool;
    use tokio::sync::OnceCell;

    /// Guard ensuring the shared test environment is prepared only once per
    /// test binary.
    static CONTEXT: OnceCell<()> = OnceCell::const_new();

    /// Runs one-time test setup. With the `testlog` feature enabled, it installs
    /// an `env_logger` in test mode at `Trace` level; otherwise it does nothing.
    async fn setup() {
        let prepare = || async {
            #[cfg(feature = "testlog")]
            let _ = env_logger::builder()
                .is_test(true)
                .filter_level(log::LevelFilter::Trace)
                .try_init();
        };
        CONTEXT.get_or_init(prepare).await;
    }

    /// After the migrations run, the `config` table must carry the
    /// `initialized` marker with the value `YES`.
    #[sqlx::test(migrations = "../migrations")]
    async fn migrations_test(pool: PgPool) -> sqlx::Result<()> {
        setup().await;

        let mut connection = pool.acquire().await?;
        let initialized_flag: String = sqlx::query_scalar(
            "SELECT contents FROM config WHERE lower(field) = 'initialized'",
        )
        .fetch_one(&mut *connection)
        .await?;

        assert_eq!(initialized_flag, "YES");

        Ok(())
    }
}