1use crate::config::ConfigError;
2use cfg_if::cfg_if;
3use sqlx::migrate::MigrateError;
4use sqlx::pool::PoolConnection;
5use sqlx::{PgPool, Postgres};
6use thiserror::Error;
7cfg_if! {
8 if #[cfg(test)] {
9 use std::cell::Cell;
10 } else if #[cfg(not(test))] {
11 use sqlx::postgres::PgPoolOptions;
12 use std::env::var;
13 use std::sync::LazyLock;
14 use std::time::Duration;
15 use tokio::sync::OnceCell;
16
17 static DB_URL: LazyLock<String> = LazyLock::new(|| {
18 var("DATABASE_URL")
19 .unwrap_or_else(|_| panic!("{}", String::from(t!("DATABASE_URL is not provided"))))
20 });
21 }
22}
23
24#[derive(Debug, Error)]
25pub enum DBError {
26 #[error("Database error: {0}")]
27 Sqlx(#[from] sqlx::Error),
28 #[error("DB migration error: {0}")]
29 Migration(#[from] MigrateError),
30 #[error("Configuration access error")]
31 Config(#[from] ConfigError),
32}
33
34pub async fn migrate_db() -> Result<(), DBError> {
35 Ok(sqlx::migrate!("../migrations")
36 .run(get_pool().await?)
37 .await?)
38}
39
40pub async fn get_connection() -> Result<PoolConnection<Postgres>, DBError> {
41 Ok(get_pool().await?.acquire().await?)
42}
43
44cfg_if! {
45 if #[cfg(test)] {
46 thread_local! (pub static DB_POOL: Cell<*const PgPool> = panic!("!"));
48
49 async fn get_pool() -> Result<&'static PgPool, DBError> {
50 unsafe {Ok(&*DB_POOL.get())}
51 }
52
53 } else if #[cfg(not(test))] {
54 static DB_POOL: OnceCell<PgPool> = OnceCell::const_new();
56
57 async fn get_pool() -> Result<&'static PgPool, DBError> {
58 Ok(DB_POOL
59 .get_or_init(|| async {
60 log::debug!("Pool initialization");
61 let options = PgPoolOptions::new()
62 .max_connections(10)
63 .acquire_timeout(Duration::from_secs(10));
64 options.connect(&DB_URL).await.unwrap()
65 }).await)
66 }
67 }
68}
69
#[cfg(test)]
mod db_tests {
    use sqlx::PgPool;
    use tokio::sync::OnceCell;

    /// Guards the one-time setup shared by all tests in this module.
    static CONTEXT: OnceCell<()> = OnceCell::const_new();

    /// Installs a trace-level test logger when the `testlog` feature is on;
    /// otherwise does nothing. A logger that is already installed is ignored.
    fn init_test_logging() {
        #[cfg(feature = "testlog")]
        {
            let _ = env_logger::builder()
                .is_test(true)
                .filter_level(log::LevelFilter::Trace)
                .try_init();
        }
    }

    /// Runs the shared setup at most once per test process.
    async fn setup() {
        CONTEXT.get_or_init(|| async { init_test_logging() }).await;
    }

    /// After migrations run, the `config` table must hold `initialized = YES`.
    #[sqlx::test(migrations = "../migrations")]
    async fn migrations_test(pool: PgPool) -> sqlx::Result<()> {
        setup().await;

        let mut connection = pool.acquire().await?;
        let initialized: String = sqlx::query_scalar(
            "SELECT contents FROM config WHERE lower(field) = 'initialized'",
        )
        .fetch_one(&mut *connection)
        .await?;

        assert_eq!(initialized, "YES");
        Ok(())
    }
}