Lines
85.71 %
Functions
50 %
Branches
100 %
use crate::config::ConfigError;
use cfg_if::cfg_if;
use sqlx::migrate::MigrateError;
use sqlx::pool::PoolConnection;
use sqlx::{PgPool, Postgres};
use thiserror::Error;
cfg_if! {
if #[cfg(test)] {
use std::cell::Cell;
} else if #[cfg(not(test))] {
use sqlx::postgres::PgPoolOptions;
use std::env::var;
use std::sync::LazyLock;
use std::time::Duration;
use tokio::sync::OnceCell;
static DB_URL: LazyLock<String> = LazyLock::new(|| {
var("DATABASE_URL")
.unwrap_or_else(|_| panic!("{}", String::from(t!("DATABASE_URL is not provided"))))
});
}
#[derive(Debug, Error)]
pub enum DBError {
#[error("Database error: {0}")]
Sqlx(#[from] sqlx::Error),
#[error("DB migration error: {0}")]
Migration(#[from] MigrateError),
#[error("Configuration access error")]
Config(#[from] ConfigError),
pub async fn migrate_db() -> Result<(), DBError> {
Ok(sqlx::migrate!("../migrations")
.run(get_pool().await?)
.await?)
pub async fn get_connection() -> Result<PoolConnection<Postgres>, DBError> {
Ok(get_pool().await?.acquire().await?)
// Just a place to store a pointer that can be mocked from inside tests
thread_local! (pub static DB_POOL: Cell<*const PgPool> = panic!("!"));
async fn get_pool() -> Result<&'static PgPool, DBError> {
unsafe {Ok(&*DB_POOL.get())}
// And a correct storage for the run-time
static DB_POOL: OnceCell<PgPool> = OnceCell::const_new();
Ok(DB_POOL
.get_or_init(|| async {
log::debug!("Pool initialization");
let options = PgPoolOptions::new()
.max_connections(10)
.acquire_timeout(Duration::from_secs(10));
options.connect(&DB_URL).await.unwrap()
}).await)
#[cfg(test)]
mod db_tests {
    use sqlx::PgPool;
    use tokio::sync::OnceCell;

    /// Context for keeping environment intact: guards the one-time test initialisation.
    static CONTEXT: OnceCell<()> = OnceCell::const_new();

    /// Initialisation body run through `CONTEXT`; does work only with the `testlog` feature.
    async fn init_test_logging() {
        // Trace-level test logging. `try_init` errors are ignored so an already
        // installed logger is not treated as a failure.
        #[cfg(feature = "testlog")]
        let _ = env_logger::builder()
            .is_test(true)
            .filter_level(log::LevelFilter::Trace)
            .try_init();
    }

    /// Runs the shared test setup once, no matter how many tests call it.
    async fn setup() {
        CONTEXT.get_or_init(init_test_logging).await;
    }

    /// After the migrations are applied, the `config` row whose `field` is
    /// (case-insensitively) `initialized` must contain `"YES"`.
    #[sqlx::test(migrations = "../migrations")]
    async fn migrations_test(pool: PgPool) -> sqlx::Result<()> {
        setup().await;

        let mut conn = pool.acquire().await?;
        let query = "SELECT contents FROM config WHERE lower(field) = 'initialized'";
        let init: String = sqlx::query_scalar(query).fetch_one(&mut *conn).await?;

        assert_eq!(init, "YES");
        Ok(())
    }
}