Lines
62.12 %
Functions
10.48 %
Branches
100 %
use crate::config::config;
use crate::db::DBError;
use sqlx::pool::PoolConnection;
use sqlx::types::Uuid;
use sqlx::{Pool, Postgres, postgres::PgPoolOptions, query_file_scalar};
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant};
use tokio::sync::OnceCell;
/// A per-user connection pool together with the moment it was last handed out.
///
/// Stored as the value type of the registry map; `last_used` is compared
/// against the registry's expiration to decide when the entry is evicted.
struct UserPool {
// Pool of Postgres connections created for one user's database URL.
pool: Pool<Postgres>,
// Refreshed each time the pool is looked up; read via `elapsed()` during eviction.
last_used: Instant,
}
struct PoolRegistry {
pools: Arc<Mutex<HashMap<Uuid, UserPool>>>,
expiration: Duration,
impl PoolRegistry {
fn new(expiration: Duration) -> Self {
Self {
pools: Arc::new(Mutex::new(HashMap::new())),
expiration,
async fn acquire_connection(&self, user_id: Uuid) -> Result<PoolConnection<Postgres>, DBError> {
// First, try to get a connection from an existing pool
let pool_option = {
let pools_arc = Arc::clone(&self.pools);
let mut pools = pools_arc
.lock()
.map_err(|_| DBError::Config(crate::config::ConfigError::DB))?;
// Remove expired pools
pools.retain(|_, user_pool| user_pool.last_used.elapsed() < self.expiration);
// If pool exists, update last_used and return a clone of the pool
if let Some(user_pool) = pools.get_mut(&user_id) {
user_pool.last_used = Instant::now();
Some(user_pool.pool.clone())
} else {
None
};
// Now try to acquire a connection from the cloned pool outside the mutex guard
if let Some(pool) = pool_option {
match pool.acquire().await {
Ok(conn) => return Ok(conn),
Err(e) => {
// If we can't get a connection, log and fall through to recreate the pool
log::warn!(
"Failed to acquire connection for user {user_id} from existing pool: {e}"
);
// If we get here, we need to create a new pool
let url = self.get_db_url_for_user(user_id).await?;
let options = PgPoolOptions::new()
.max_connections(10)
.acquire_timeout(Duration::from_secs(10));
let pool = options.connect(&url).await.map_err(DBError::Sqlx)?;
// Get a connection before storing the pool
let conn = pool.acquire().await.map_err(DBError::Sqlx)?;
// Store the new pool
{
pools.insert(
user_id,
UserPool {
pool: pool.clone(),
last_used: Instant::now(),
},
Ok(conn)
async fn get_db_url_for_user(&self, user_id: Uuid) -> Result<String, DBError> {
let mut conn = crate::db::get_connection().await?;
query_file_scalar!("sql/select/system/db_uid.sql", &user_id)
.fetch_one(&mut *conn)
.await
.map_err(DBError::Sqlx)
/// Process-wide pool registry, initialised lazily on first use by
/// `get_pool_registry`.
static POOL_REGISTRY: OnceCell<PoolRegistry> = OnceCell::const_new();
async fn get_pool_registry() -> &'static PoolRegistry {
POOL_REGISTRY
.get_or_init(|| async move {
let timeout = match config("userregistrytimeout").await {
Ok(Some(value)) => value.to_string().parse().unwrap_or(3600),
_ => 3600,
PoolRegistry::new(Duration::from_secs(timeout))
})
// The main function that clients will call to get a connection
pub(in crate::user) async fn get_connection(
user_id: Uuid,
) -> Result<PoolConnection<Postgres>, DBError> {
let registry = get_pool_registry().await;
registry.acquire_connection(user_id).await