autopulse_database/
conn.rs

1use crate::models::{NewScanEvent, ScanEvent};
2use anyhow::Context;
3use autopulse_utils::sify;
4use diesel::connection::SimpleConnection;
5use diesel::r2d2::{ConnectionManager, Pool, PooledConnection};
6use diesel::{Connection, RunQueryDsl};
7use diesel::{SaveChangesDsl, SelectableHelper};
8use diesel_migrations::{embed_migrations, EmbeddedMigrations, MigrationHarness};
9use serde::Deserialize;
10use std::path::PathBuf;
11use tracing::info;
12
/// `SQLite` schema migrations, embedded into the binary from `migrations/sqlite`.
#[doc(hidden)]
#[cfg(feature = "sqlite")]
const SQLITE_MIGRATIONS: EmbeddedMigrations = embed_migrations!("migrations/sqlite");

/// `PostgreSQL` schema migrations, embedded into the binary from `migrations/postgres`.
#[doc(hidden)]
#[cfg(feature = "postgres")]
const POSTGRES_MIGRATIONS: EmbeddedMigrations = embed_migrations!("migrations/postgres");
20
21#[derive(Deserialize, Debug)]
22#[serde(rename_all = "lowercase")]
23#[derive(Default)]
24pub enum DatabaseType {
25    #[cfg(feature = "sqlite")]
26    #[cfg_attr(feature = "sqlite", default)]
27    Sqlite,
28    #[cfg(feature = "postgres")]
29    #[cfg_attr(not(feature = "sqlite"), default)]
30    Postgres,
31}
32
33impl DatabaseType {
34    pub fn default_url(&self) -> String {
35        match self {
36            #[cfg(feature = "sqlite")]
37            Self::Sqlite => "sqlite://data/autopulse.db".to_string(),
38            #[cfg(feature = "postgres")]
39            Self::Postgres => "postgres://autopulse:autopulse@localhost:5432/autopulse".to_string(),
40        }
41    }
42}
43
/// Represents a connection to either a `PostgreSQL` or `SQLite` database.
///
/// The backend is picked from the database URL (see the variant docs below). Which variants
/// exist depends on whether the `postgres` and/or `sqlite` cargo features are enabled.
#[derive(diesel::MultiConnection)]
pub enum AnyConnection {
    // NOTE(review): `diesel::MultiConnection` generates `establish` from these variants;
    // confirm whether variant order decides which backend is tried first before reordering.
    /// A connection to a `PostgreSQL` database.
    ///
    /// This is used when the `database_url` is a `PostgreSQL` URL.
    ///
    /// # Example
    ///
    /// ```text
    /// postgres://user:password@localhost:5432/database
    /// ```
    #[cfg(feature = "postgres")]
    Postgresql(diesel::PgConnection),
    // Mysql(diesel::MysqlConnection),
    /// A connection to a `SQLite` database.
    ///
    /// This is used when the `database_url` is a `SQLite` URL.
    ///
    /// Note: the database is switched to WAL journal mode during pool setup, so `SQLite` also
    /// creates `<db>-wal` and `<db>-shm` files next to the database file while it is in use.
    ///
    /// # Example
    ///
    /// ```text
    /// # Relative path
    /// sqlite://database.db
    /// sqlite://data/database.db
    ///
    /// # Absolute path
    /// sqlite:///data/database.db
    ///
    /// # In-memory database
    /// sqlite://:memory:
    /// ```
    // NOTE(review): in SQLite every `:memory:` connection is a distinct database, so each
    // pooled connection presumably sees its own empty database — confirm this is intended.
    #[cfg(feature = "sqlite")]
    Sqlite(diesel::SqliteConnection),
}
81
/// r2d2 connection customizer that applies per-connection settings when the pool creates
/// a new connection.
#[doc(hidden)]
#[derive(Default, Debug)]
pub struct AcquireHook {
    /// When `true`, one-time database maintenance also runs on acquire
    /// (`SQLite`: WAL journal mode + `VACUUM`; `PostgreSQL`: `VACUUM ANALYZE`).
    pub setup: bool,
}
87
88impl diesel::r2d2::CustomizeConnection<AnyConnection, diesel::r2d2::Error> for AcquireHook {
89    fn on_acquire(&self, conn: &mut AnyConnection) -> Result<(), diesel::r2d2::Error> {
90        (|| {
91            match conn {
92                #[cfg(feature = "sqlite")]
93                AnyConnection::Sqlite(ref mut conn) => {
94                    conn.batch_execute("PRAGMA busy_timeout = 5000")?;
95                    conn.batch_execute("PRAGMA synchronous = NORMAL;")?;
96                    conn.batch_execute("PRAGMA wal_autocheckpoint = 1000;")?;
97                    conn.batch_execute("PRAGMA foreign_keys = ON;")?;
98
99                    if self.setup {
100                        conn.batch_execute("PRAGMA journal_mode = WAL;")?;
101                        conn.batch_execute("VACUUM")?;
102                    }
103                }
104                #[cfg(feature = "postgres")]
105                AnyConnection::Postgresql(ref mut conn) => {
106                    if self.setup {
107                        conn.batch_execute("VACUUM ANALYZE")?;
108                    }
109                }
110            }
111            Ok(())
112        })()
113        .map_err(diesel::r2d2::Error::QueryError)
114    }
115}
116
117impl AnyConnection {
118    pub fn pre_init(database_url: &str) -> anyhow::Result<()> {
119        if database_url.starts_with("sqlite://") && !database_url.contains(":memory:") {
120            let path = database_url
121                .strip_prefix("sqlite://")
122                .expect("already checked prefix");
123
124            let path = PathBuf::from(path);
125
126            let Some(parent) = path.parent().filter(|p| !p.as_os_str().is_empty()) else {
127                return Ok(());
128            };
129
130            // Create directory if it doesn't exist
131            if !parent.exists() {
132                std::fs::create_dir_all(parent).with_context(|| {
133                    format!("failed to create database directory: {}", parent.display())
134                })?;
135            }
136        }
137
138        Ok(())
139    }
140
141    pub fn migrate(&mut self) -> anyhow::Result<()> {
142        let migrations_applied = match self {
143            #[cfg(feature = "postgres")]
144            Self::Postgresql(conn) => conn.run_pending_migrations(POSTGRES_MIGRATIONS),
145            #[cfg(feature = "sqlite")]
146            Self::Sqlite(conn) => conn.run_pending_migrations(SQLITE_MIGRATIONS),
147        }
148        .expect("Could not run migrations");
149
150        if !migrations_applied.is_empty() {
151            info!(
152                "Applied {} migration{}",
153                migrations_applied.len(),
154                sify(&migrations_applied)
155            );
156        }
157
158        Ok(())
159    }
160
161    pub fn close(&mut self) -> anyhow::Result<()> {
162        match self {
163            #[cfg(feature = "postgres")]
164            Self::Postgresql(_) => {}
165            #[cfg(feature = "sqlite")]
166            Self::Sqlite(conn) => {
167                // Should cleanup spare wal/shm files
168                conn.batch_execute("PRAGMA wal_checkpoint(TRUNCATE);")
169                    .context("failed to checkpoint WAL")?;
170            }
171        }
172
173        Ok(())
174    }
175
176    pub fn save_changes(&mut self, ev: &mut ScanEvent) -> anyhow::Result<ScanEvent> {
177        let ev = match self {
178            #[cfg(feature = "postgres")]
179            Self::Postgresql(conn) => ev.save_changes::<ScanEvent>(conn),
180            // #[cfg(feature = "mysql")]
181            // AnyConnection::Mysql(conn) => ev.save_changes::<ScanEvent>(conn),
182            #[cfg(feature = "sqlite")]
183            Self::Sqlite(conn) => ev.save_changes::<ScanEvent>(conn),
184        }?;
185
186        Ok(ev)
187    }
188
189    pub fn insert_and_return(&mut self, ev: &NewScanEvent) -> anyhow::Result<ScanEvent> {
190        match self {
191            #[cfg(feature = "postgres")]
192            Self::Postgresql(conn) => diesel::insert_into(crate::schema::scan_events::table)
193                .values(ev)
194                .returning(ScanEvent::as_returning())
195                .get_result::<ScanEvent>(conn)
196                .map_err(Into::into),
197            #[cfg(feature = "sqlite")]
198            Self::Sqlite(conn) => diesel::insert_into(crate::schema::scan_events::table)
199                .values(ev)
200                .returning(ScanEvent::as_returning())
201                .get_result::<ScanEvent>(conn)
202                .map_err(Into::into),
203        }
204    }
205}
206
207#[doc(hidden)]
208pub type DbPool = Pool<ConnectionManager<AnyConnection>>;
209
210#[doc(hidden)]
211pub fn get_conn(
212    pool: &Pool<ConnectionManager<AnyConnection>>,
213) -> anyhow::Result<PooledConnection<ConnectionManager<AnyConnection>>> {
214    pool.get().context("failed to get connection from pool")
215}
216
217pub fn close_pool(pool: &Pool<ConnectionManager<AnyConnection>>) {
218    if let Ok(mut conn) = pool.get() {
219        let _ = conn.close();
220    }
221}
222
223#[doc(hidden)]
224pub fn get_pool(database_url: &String) -> anyhow::Result<Pool<ConnectionManager<AnyConnection>>> {
225    let manager = ConnectionManager::<AnyConnection>::new(database_url);
226
227    let pool = Pool::builder()
228        .max_size(1)
229        .connection_customizer(Box::new(AcquireHook { setup: true }))
230        .build(manager)
231        .context("failed to create pool");
232
233    drop(pool);
234
235    let manager = ConnectionManager::<AnyConnection>::new(database_url);
236
237    Pool::builder()
238        .connection_customizer(Box::new(AcquireHook::default()))
239        .build(manager)
240        .context("failed to create pool")
241}
242
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;
    use tempfile::tempdir;

    /// In-memory databases need no directory, so `pre_init` is a no-op.
    #[test]
    fn test_pre_init_memory_db_skipped() {
        assert!(AnyConnection::pre_init("sqlite://:memory:").is_ok());
    }

    /// A missing parent directory of the database file is created.
    #[test]
    fn test_pre_init_creates_directory() {
        let dir = tempdir().unwrap();
        let db_file = dir.path().join("subdir").join("test.db");

        let outcome = AnyConnection::pre_init(&format!("sqlite://{}", db_file.display()));

        assert!(outcome.is_ok());
        assert!(db_file.parent().unwrap().exists());
    }

    /// A bare file name has no parent directory to create.
    #[test]
    fn test_pre_init_no_parent_directory() {
        assert!(AnyConnection::pre_init("sqlite://test.db").is_ok());
    }

    /// An already existing, writable parent directory is accepted.
    #[test]
    fn test_pre_init_writable_directory_succeeds() {
        let dir = tempdir().unwrap();
        let existing = dir.path().join("writable");
        fs::create_dir(&existing).unwrap();

        let url = format!("sqlite://{}", existing.join("test.db").display());

        assert!(AnyConnection::pre_init(&url).is_ok());
    }

    /// Non-SQLite URLs are ignored.
    #[test]
    fn test_pre_init_postgres_skipped() {
        assert!(AnyConnection::pre_init("postgres://localhost/test").is_ok());
    }

    /// Closing the pool checkpoints the WAL so no `-wal`/`-shm` files are left behind.
    #[test]
    #[cfg(feature = "sqlite")]
    fn test_close_pool_cleans_up_wal_files() {
        let dir = tempdir().unwrap();
        let url = format!("sqlite://{}", dir.path().join("test.db").display());

        AnyConnection::pre_init(&url).unwrap();
        let pool = get_pool(&url).unwrap();

        // Running migrations writes to the database, so WAL/SHM files may exist afterwards.
        get_conn(&pool).unwrap().migrate().unwrap();

        close_pool(&pool);
        drop(pool);

        for (leftover, message) in [
            ("test.db-wal", "WAL file should be cleaned up"),
            ("test.db-shm", "SHM file should be cleaned up"),
        ] {
            assert!(!dir.path().join(leftover).exists(), "{}", message);
        }
    }
}