1use crate::config::ConfigError;
2use cfg_if::cfg_if;
3use sqlx::migrate::MigrateError;
4use sqlx::pool::PoolConnection;
5use sqlx::{PgPool, Postgres};
6use thiserror::Error;
7cfg_if! {
8 if #[cfg(test)] {
9 use std::cell::Cell;
10 } else if #[cfg(feature = "test-utils")] {
11 use std::cell::Cell;
12 use sqlx::postgres::PgPoolOptions;
13 use std::env::var;
14 use std::sync::LazyLock;
15 use std::time::Duration;
16 use tokio::sync::OnceCell;
17
18 static DB_URL: LazyLock<String> = LazyLock::new(|| {
19 var("DATABASE_URL")
20 .unwrap_or_else(|_| panic!("{}", String::from(t!("DATABASE_URL is not provided"))))
21 });
22 } else {
23 use sqlx::postgres::PgPoolOptions;
24 use std::env::var;
25 use std::sync::LazyLock;
26 use std::time::Duration;
27 use tokio::sync::OnceCell;
28
29 static DB_URL: LazyLock<String> = LazyLock::new(|| {
30 var("DATABASE_URL")
31 .unwrap_or_else(|_| panic!("{}", String::from(t!("DATABASE_URL is not provided"))))
32 });
33 }
34}
35
36#[derive(Debug, Error)]
37pub enum DBError {
38 #[error("Database error: {0}")]
39 Sqlx(#[from] sqlx::Error),
40 #[error("DB migration error: {0}")]
41 Migration(#[from] MigrateError),
42 #[error("Configuration access error")]
43 Config(#[from] ConfigError),
44 #[error("DATABASE_URL is not provided")]
45 MissingUrl,
46 #[error("The database role lacks CREATEDB privilege")]
47 NoCreateDb,
48 #[error("Failed to generate an authentication keypair")]
49 KeyGen,
50}
51
52pub async fn migrate_db() -> Result<(), DBError> {
53 Ok(sqlx::migrate!("../migrations")
54 .run(get_pool().await?)
55 .await?)
56}
57
58pub async fn get_connection() -> Result<PoolConnection<Postgres>, DBError> {
59 Ok(get_pool().await?.acquire().await?)
60}
61
62pub async fn execute_raw(sql: &str) -> Result<(), DBError> {
67 sqlx::raw_sql(sql).execute(get_pool().await?).await?;
68 Ok(())
69}
70
71cfg_if! {
72 if #[cfg(test)] {
73 thread_local!(pub static DB_POOL: Cell<*const PgPool> = const {
79 Cell::new(std::ptr::null())
80 });
81
82 #[must_use]
86 pub fn test_pool_is_set() -> bool {
87 DB_POOL.with(|c| !c.get().is_null())
88 }
89
90 async fn get_pool() -> Result<&'static PgPool, DBError> {
91 let p = DB_POOL.with(|c| c.get());
92 assert!(!p.is_null(), "DB_POOL must be set; see local_db_sqlx_test macro");
93 unsafe { Ok(&*p) }
94 }
95
96 } else if #[cfg(feature = "test-utils")] {
97 thread_local!(pub static DB_POOL: Cell<*const PgPool> = const {
105 Cell::new(std::ptr::null())
106 });
107
108 #[must_use]
112 pub fn test_pool_is_set() -> bool {
113 DB_POOL.with(|c| !c.get().is_null())
114 }
115
116 static FALLBACK_POOL: OnceCell<PgPool> = OnceCell::const_new();
117
118 async fn get_pool() -> Result<&'static PgPool, DBError> {
119 let p = DB_POOL.with(|c| c.get());
120 if !p.is_null() {
121 return unsafe { Ok(&*p) };
122 }
123 Ok(FALLBACK_POOL
124 .get_or_init(|| async {
125 log::debug!("Fallback pool initialization");
126 let options = PgPoolOptions::new()
127 .max_connections(10)
128 .acquire_timeout(Duration::from_secs(10));
129 options.connect(&DB_URL).await.unwrap()
130 }).await)
131 }
132
133 } else {
134 static DB_POOL: OnceCell<PgPool> = OnceCell::const_new();
136
137 async fn get_pool() -> Result<&'static PgPool, DBError> {
138 Ok(DB_POOL
139 .get_or_init(|| async {
140 log::debug!("Pool initialization");
141 let options = PgPoolOptions::new()
142 .max_connections(10)
143 .acquire_timeout(Duration::from_secs(10));
144 options.connect(&DB_URL).await.unwrap()
145 }).await)
146 }
147 }
148}
149
150pub fn admin_database_url() -> Result<String, DBError> {
163 std::env::var("DATABASE_URL").map_err(|_| DBError::MissingUrl)
164}
165
166pub async fn migrate_url(url: &str) -> Result<(), DBError> {
174 let pool = sqlx::postgres::PgPoolOptions::new()
175 .max_connections(1)
176 .acquire_timeout(std::time::Duration::from_secs(10))
177 .connect(url)
178 .await?;
179 let result = sqlx::migrate!("../migrations").run(&pool).await;
180 pool.close().await;
184 result.map_err(DBError::from)
185}
186
#[cfg(test)]
mod db_tests {
    use sqlx::PgPool;
    use tokio::sync::OnceCell;

    /// Process-wide guard so logging setup runs at most once.
    static LOG_INIT: OnceCell<()> = OnceCell::const_new();

    /// Installs a trace-level test logger when the `testlog` feature is enabled;
    /// otherwise does nothing.
    async fn init_logging() {
        #[cfg(feature = "testlog")]
        let _ = env_logger::builder()
            .is_test(true)
            .filter_level(log::LevelFilter::Trace)
            .try_init();
    }

    /// Shared per-test setup.
    async fn setup() {
        LOG_INIT.get_or_init(init_logging).await;
    }

    /// Running the migrations must create the `config` table without seeding it.
    #[sqlx::test(migrations = "../migrations")]
    async fn migrations_create_schema_without_seed(pool: PgPool) -> sqlx::Result<()> {
        setup().await;

        let mut conn = pool.acquire().await?;
        let row_count: i64 = sqlx::query_scalar("SELECT count(*) FROM config")
            .fetch_one(&mut *conn)
            .await?;

        assert_eq!(row_count, 0, "migration set must not seed config rows");
        Ok(())
    }
}