1
use crate::config::config;
2
use crate::db::DBError;
3
use sqlx::pool::PoolConnection;
4
use sqlx::types::Uuid;
5
use sqlx::{Pool, Postgres, postgres::PgPoolOptions, query_file_scalar};
6
use std::collections::HashMap;
7
use std::sync::{Arc, Mutex};
8
use std::time::{Duration, Instant};
9
use tokio::sync::OnceCell;
10

            
11
struct UserPool {
12
    pool: Pool<Postgres>,
13
    last_used: Instant,
14
}
15

            
16
struct PoolRegistry {
17
    pools: Arc<Mutex<HashMap<Uuid, UserPool>>>,
18
    expiration: Duration,
19
}
20

            
21
impl PoolRegistry {
22
6
    fn new(expiration: Duration) -> Self {
23
6
        Self {
24
6
            pools: Arc::new(Mutex::new(HashMap::new())),
25
6
            expiration,
26
6
        }
27
6
    }
28

            
29
238
    async fn acquire_connection(&self, user_id: Uuid) -> Result<PoolConnection<Postgres>, DBError> {
30
        // First, try to get a connection from an existing pool
31
68
        let pool_option = {
32
68
            let pools_arc = Arc::clone(&self.pools);
33
68
            let mut pools = pools_arc
34
68
                .lock()
35
68
                .map_err(|_| DBError::Config(crate::config::ConfigError::DB))?;
36

            
37
            // Remove expired pools
38
68
            pools.retain(|_, user_pool| user_pool.last_used.elapsed() < self.expiration);
39

            
40
            // If pool exists, update last_used and return a clone of the pool
41
68
            if let Some(user_pool) = pools.get_mut(&user_id) {
42
                user_pool.last_used = Instant::now();
43
                Some(user_pool.pool.clone())
44
            } else {
45
68
                None
46
            }
47
        };
48

            
49
        // Now try to acquire a connection from the cloned pool outside the mutex guard
50
68
        if let Some(pool) = pool_option {
51
            match pool.acquire().await {
52
                Ok(conn) => return Ok(conn),
53
                Err(e) => {
54
                    // If we can't get a connection, log and fall through to recreate the pool
55
                    log::warn!(
56
                        "Failed to acquire connection for user {user_id} from existing pool: {e}"
57
                    );
58
                }
59
            }
60
68
        }
61

            
62
        // If we get here, we need to create a new pool
63
68
        let url = self.get_db_url_for_user(user_id).await?;
64

            
65
        let options = PgPoolOptions::new()
66
            .max_connections(10)
67
            .acquire_timeout(Duration::from_secs(10));
68

            
69
        let pool = options.connect(&url).await.map_err(DBError::Sqlx)?;
70

            
71
        // Get a connection before storing the pool
72
        let conn = pool.acquire().await.map_err(DBError::Sqlx)?;
73

            
74
        // Store the new pool
75
        {
76
            let pools_arc = Arc::clone(&self.pools);
77
            let mut pools = pools_arc
78
                .lock()
79
                .map_err(|_| DBError::Config(crate::config::ConfigError::DB))?;
80
            pools.insert(
81
                user_id,
82
                UserPool {
83
                    pool: pool.clone(),
84
                    last_used: Instant::now(),
85
                },
86
            );
87
        }
88

            
89
        Ok(conn)
90
68
    }
91

            
92
238
    async fn get_db_url_for_user(&self, user_id: Uuid) -> Result<String, DBError> {
93
68
        let mut conn = crate::db::get_connection().await?;
94

            
95
68
        query_file_scalar!("sql/select/system/db_uid.sql", &user_id)
96
68
            .fetch_one(&mut *conn)
97
68
            .await
98
68
            .map_err(DBError::Sqlx)
99
68
    }
100
}
101

            
102
static POOL_REGISTRY: OnceCell<PoolRegistry> = OnceCell::const_new();
103

            
104
238
async fn get_pool_registry() -> &'static PoolRegistry {
105
68
    POOL_REGISTRY
106
68
        .get_or_init(|| async move {
107
2
            let timeout = match config("userregistrytimeout").await {
108
                Ok(Some(value)) => value.to_string().parse().unwrap_or(3600),
109
2
                _ => 3600,
110
            };
111
2
            PoolRegistry::new(Duration::from_secs(timeout))
112
4
        })
113
68
        .await
114
68
}
115

            
116
// The main function that clients will call to get a connection
117
204
pub(in crate::user) async fn get_connection(
118
204
    user_id: Uuid,
119
238
) -> Result<PoolConnection<Postgres>, DBError> {
120
68
    let registry = get_pool_registry().await;
121
68
    registry.acquire_connection(user_id).await
122
68
}