Skip to main content

autopulse_database/
conn.rs

1use crate::models::{NewScanEvent, ScanEvent};
2use anyhow::Context;
3use autopulse_utils::sify;
4use diesel::connection::SimpleConnection;
5use diesel::r2d2::{ConnectionManager, Pool, PooledConnection};
6use diesel::{Connection, RunQueryDsl};
7use diesel::{SaveChangesDsl, SelectableHelper};
8use diesel_migrations::{embed_migrations, EmbeddedMigrations, MigrationHarness};
9use serde::Deserialize;
10use std::path::PathBuf;
11use tracing::{info, warn};
12
// PostgreSQL migrations, embedded via `embed_migrations!` from `migrations/postgres`.
#[cfg(feature = "postgres")]
#[doc(hidden)]
const POSTGRES_MIGRATIONS: EmbeddedMigrations = embed_migrations!("migrations/postgres");
16
// SQLite migrations, embedded via `embed_migrations!` from `migrations/sqlite`.
#[cfg(feature = "sqlite")]
#[doc(hidden)]
const SQLITE_MIGRATIONS: EmbeddedMigrations = embed_migrations!("migrations/sqlite");
20
21#[derive(Deserialize, Debug)]
22#[serde(rename_all = "lowercase")]
23#[derive(Default)]
24pub enum DatabaseType {
25    #[cfg(feature = "sqlite")]
26    #[cfg_attr(feature = "sqlite", default)]
27    Sqlite,
28    #[cfg(feature = "postgres")]
29    #[cfg_attr(not(feature = "sqlite"), default)]
30    Postgres,
31}
32
33impl DatabaseType {
34    pub fn default_url(&self) -> String {
35        match self {
36            #[cfg(feature = "sqlite")]
37            Self::Sqlite => "sqlite://data/autopulse.db".to_string(),
38            #[cfg(feature = "postgres")]
39            Self::Postgres => "postgres://autopulse:autopulse@localhost:5432/autopulse".to_string(),
40        }
41    }
42}
43
/// Represents a connection to either a `PostgreSQL` or `SQLite` database.
///
/// Which variants exist depends on the enabled Cargo features (`postgres`, `sqlite`).
// NOTE(review): `diesel::MultiConnection` generates the backend dispatch from these
// variants; its `establish` is believed to try them in declaration order, so reordering
// them may change which backend a URL is opened with — confirm against the diesel docs.
#[derive(diesel::MultiConnection)]
pub enum AnyConnection {
    /// A connection to a `PostgreSQL` database.
    ///
    /// This is used when the `database_url` is a `PostgreSQL` URL.
    ///
    /// # Example
    ///
    /// ```text
    /// postgres://user:password@localhost:5432/database
    /// ```
    #[cfg(feature = "postgres")]
    Postgresql(diesel::PgConnection),
    // Mysql(diesel::MysqlConnection),
    /// A connection to a `SQLite` database.
    ///
    /// This is used when the `database_url` is a `SQLite` URL.
    ///
    /// Note: pool setup switches `SQLite` to WAL journaling, so the directory holding the
    /// database will also contain `-wal` and `-shm` side files next to the database file.
    ///
    /// # Example
    ///
    /// ```text
    /// # Relative path
    /// sqlite://database.db
    /// sqlite://data/database.db
    ///
    /// # Absolute path
    /// sqlite:///data/database.db
    ///
    /// # In-memory database
    /// sqlite://:memory:
    /// ```
    #[cfg(feature = "sqlite")]
    Sqlite(diesel::SqliteConnection),
}
81
/// Connection customizer attached to the pools built by `get_pool`.
///
/// Every acquired connection gets the per-connection settings. When `setup` is `true`,
/// one-off maintenance also runs: WAL mode + `VACUUM` on `SQLite`, `VACUUM ANALYZE`
/// on `PostgreSQL`.
#[derive(Default, Debug)]
#[doc(hidden)]
pub struct AcquireHook {
    /// Also run the one-off setup statements when a connection is acquired.
    pub setup: bool,
}
87
88impl diesel::r2d2::CustomizeConnection<AnyConnection, diesel::r2d2::Error> for AcquireHook {
89    fn on_acquire(&self, conn: &mut AnyConnection) -> Result<(), diesel::r2d2::Error> {
90        (|| {
91            match conn {
92                #[cfg(feature = "sqlite")]
93                AnyConnection::Sqlite(ref mut conn) => {
94                    conn.batch_execute("PRAGMA busy_timeout = 5000")?;
95                    conn.batch_execute("PRAGMA synchronous = NORMAL;")?;
96                    conn.batch_execute("PRAGMA wal_autocheckpoint = 1000;")?;
97                    conn.batch_execute("PRAGMA foreign_keys = ON;")?;
98
99                    if self.setup {
100                        conn.batch_execute("PRAGMA journal_mode = WAL;")?;
101                        conn.batch_execute("VACUUM")?;
102                    }
103                }
104                #[cfg(feature = "postgres")]
105                AnyConnection::Postgresql(ref mut conn) => {
106                    if self.setup {
107                        conn.batch_execute("VACUUM ANALYZE")?;
108                    }
109                }
110            }
111            Ok(())
112        })()
113        .map_err(diesel::r2d2::Error::QueryError)
114    }
115}
116
117impl AnyConnection {
118    pub fn pre_init(database_url: &str) -> anyhow::Result<()> {
119        if database_url.starts_with("sqlite://") && !database_url.contains(":memory:") {
120            let path = database_url
121                .strip_prefix("sqlite://")
122                .expect("already checked prefix");
123
124            let path = PathBuf::from(path);
125
126            let Some(parent) = path.parent().filter(|p| !p.as_os_str().is_empty()) else {
127                return Ok(());
128            };
129
130            // Create directory if it doesn't exist
131            if !parent.exists() {
132                std::fs::create_dir_all(parent).with_context(|| {
133                    format!("failed to create database directory: {}", parent.display())
134                })?;
135            }
136        }
137
138        Ok(())
139    }
140
141    pub fn migrate(&mut self) -> anyhow::Result<()> {
142        let migrations_applied = match self {
143            #[cfg(feature = "postgres")]
144            Self::Postgresql(conn) => conn.run_pending_migrations(POSTGRES_MIGRATIONS),
145            #[cfg(feature = "sqlite")]
146            Self::Sqlite(conn) => conn.run_pending_migrations(SQLITE_MIGRATIONS),
147        }
148        // Preserve `e.source()` chain via anyhow::Error::from_boxed; the
149        // previous `anyhow!("...{e}")` flattened it to Display only.
150        .map_err(|e| anyhow::Error::from_boxed(e).context("failed to run migrations"))?;
151
152        if !migrations_applied.is_empty() {
153            info!(
154                "Applied {} migration{}",
155                migrations_applied.len(),
156                sify(&migrations_applied)
157            );
158        }
159
160        Ok(())
161    }
162
163    pub fn close(&mut self) -> anyhow::Result<()> {
164        match self {
165            #[cfg(feature = "postgres")]
166            Self::Postgresql(_) => {}
167            #[cfg(feature = "sqlite")]
168            Self::Sqlite(conn) => {
169                // Should cleanup spare wal/shm files
170                conn.batch_execute("PRAGMA wal_checkpoint(TRUNCATE);")
171                    .context("failed to checkpoint WAL")?;
172            }
173        }
174
175        Ok(())
176    }
177
178    pub fn save_changes(&mut self, ev: &mut ScanEvent) -> anyhow::Result<ScanEvent> {
179        let ev = match self {
180            #[cfg(feature = "postgres")]
181            Self::Postgresql(conn) => ev.save_changes::<ScanEvent>(conn),
182            // #[cfg(feature = "mysql")]
183            // AnyConnection::Mysql(conn) => ev.save_changes::<ScanEvent>(conn),
184            #[cfg(feature = "sqlite")]
185            Self::Sqlite(conn) => ev.save_changes::<ScanEvent>(conn),
186        }?;
187
188        Ok(ev)
189    }
190
191    pub fn insert_and_return(&mut self, ev: &NewScanEvent) -> anyhow::Result<ScanEvent> {
192        match self {
193            #[cfg(feature = "postgres")]
194            Self::Postgresql(conn) => diesel::insert_into(crate::schema::scan_events::table)
195                .values(ev)
196                .returning(ScanEvent::as_returning())
197                .get_result::<ScanEvent>(conn)
198                .map_err(Into::into),
199            #[cfg(feature = "sqlite")]
200            Self::Sqlite(conn) => diesel::insert_into(crate::schema::scan_events::table)
201                .values(ev)
202                .returning(ScanEvent::as_returning())
203                .get_result::<ScanEvent>(conn)
204                .map_err(Into::into),
205        }
206    }
207
208    /// Inserts a queued event, or updates the existing pending/retry row for the path.
209    pub fn upsert_pending(
210        &mut self,
211        ev: &NewScanEvent,
212        now: chrono::NaiveDateTime,
213    ) -> anyhow::Result<ScanEvent> {
214        match self {
215            #[cfg(feature = "postgres")]
216            Self::Postgresql(conn) => upsert_pending_pg(conn, ev, now),
217            #[cfg(feature = "sqlite")]
218            Self::Sqlite(conn) => upsert_pending_sqlite(conn, ev, now),
219        }
220    }
221}
222
/// `PostgreSQL` implementation of `AnyConnection::upsert_pending`.
///
/// Issues a single `INSERT ... ON CONFLICT (file_path) WHERE <status is pending or retry>
/// DO UPDATE ... RETURNING`. A new event for a path that already has a pending/retry
/// row is therefore merged into that row instead of creating a duplicate. On conflict:
/// - `updated_at` is set to `now`;
/// - `can_process` becomes the later of the stored and incoming values;
/// - `file_hash` takes the incoming value only if the stored one is `NULL`.
///
/// # Errors
///
/// Returns an error if the statement fails, e.g. if no unique index matches the
/// conflict target.
#[cfg(feature = "postgres")]
fn upsert_pending_pg(
    conn: &mut diesel::PgConnection,
    ev: &NewScanEvent,
    now: chrono::NaiveDateTime,
) -> anyhow::Result<ScanEvent> {
    use crate::models::ProcessStatus;
    use crate::schema::scan_events::dsl::{
        can_process, file_hash, file_path, process_status, updated_at,
    };
    use diesel::dsl::case_when;
    use diesel::upsert::{excluded, DecoratableTarget};
    use diesel::ExpressionMethods;

    // Keep this predicate aligned with the partial index; Postgres checks that
    // match at runtime, and the smoke test covers it.
    // NOTE(review): the partial unique index itself lives in the migrations, not here.
    let pending: String = ProcessStatus::Pending.into();
    let retry: String = ProcessStatus::Retry.into();

    diesel::insert_into(crate::schema::scan_events::table)
        .values(ev)
        .on_conflict(file_path)
        .filter_target(process_status.eq_any([pending, retry]))
        .do_update()
        .set((
            updated_at.eq(now),
            // max(stored, incoming): only move `can_process` later, never earlier.
            can_process.eq(
                case_when(can_process.lt(excluded(can_process)), excluded(can_process))
                    .otherwise(can_process),
            ),
            // coalesce(stored, incoming): never overwrite a hash that is already known.
            file_hash.eq(case_when(file_hash.is_null(), excluded(file_hash)).otherwise(file_hash)),
        ))
        .returning(ScanEvent::as_returning())
        .get_result::<ScanEvent>(conn)
        .map_err(Into::into)
}
259
/// `SQLite` implementation of `AnyConnection::upsert_pending`.
///
/// Looks up the pending/retry row for `ev.file_path`. If one exists, it is updated
/// with the same merge rules as the Postgres path:
/// - `updated_at = now`,
/// - the later `can_process` wins,
/// - an existing `file_hash` is kept, otherwise the incoming one is used.
///
/// If no such row exists, `ev` is inserted. Either way the resulting row is returned.
///
/// # Errors
///
/// Returns an error if the lookup, update, or insert fails.
#[cfg(feature = "sqlite")]
fn upsert_pending_sqlite(
    conn: &mut diesel::SqliteConnection,
    ev: &NewScanEvent,
    now: chrono::NaiveDateTime,
) -> anyhow::Result<ScanEvent> {
    use crate::models::ProcessStatus;
    use crate::schema::scan_events::dsl::{
        can_process, file_hash, file_path, process_status, scan_events, updated_at,
    };
    use diesel::{ExpressionMethods, OptionalExtension, QueryDsl};

    // Diesel cannot target SQLite partial indexes here. The SQLite pool has one
    // connection, so this select-and-write sequence is serialized.
    let pending: String = ProcessStatus::Pending.into();
    let retry: String = ProcessStatus::Retry.into();

    let existing: Option<ScanEvent> = scan_events
        .filter(file_path.eq(&ev.file_path))
        .filter(process_status.eq_any([pending, retry]))
        .first::<ScanEvent>(conn)
        .optional()?;

    if let Some(existing) = existing {
        let later_can_process = std::cmp::max(existing.can_process, ev.can_process);
        let file_hash_value = existing.file_hash.clone().or_else(|| ev.file_hash.clone());
        diesel::update(&existing)
            .set((
                updated_at.eq(now),
                can_process.eq(later_can_process),
                file_hash.eq(file_hash_value),
            ))
            // Select columns by name like every other RETURNING in this module, rather
            // than relying on the implicit all-columns clause and positional decoding.
            .returning(ScanEvent::as_returning())
            .get_result::<ScanEvent>(conn)
            .map_err(Into::into)
    } else {
        diesel::insert_into(crate::schema::scan_events::table)
            .values(ev)
            .returning(ScanEvent::as_returning())
            .get_result::<ScanEvent>(conn)
            .map_err(Into::into)
    }
}
303
304#[doc(hidden)]
305pub type DbPool = Pool<ConnectionManager<AnyConnection>>;
306
307#[doc(hidden)]
308pub fn get_conn(
309    pool: &Pool<ConnectionManager<AnyConnection>>,
310) -> anyhow::Result<PooledConnection<ConnectionManager<AnyConnection>>> {
311    pool.get().context("failed to get connection from pool")
312}
313
314pub fn close_pool(pool: &Pool<ConnectionManager<AnyConnection>>) {
315    match pool.get() {
316        Ok(mut conn) => {
317            if let Err(e) = conn.close() {
318                warn!("failed to close database connection cleanly: {e}");
319            }
320        }
321        Err(e) => {
322            warn!("failed to get connection for pool shutdown: {e}");
323        }
324    }
325}
326
327#[doc(hidden)]
328pub fn get_pool(database_url: &String) -> anyhow::Result<Pool<ConnectionManager<AnyConnection>>> {
329    // First pool fires `AcquireHook { setup: true }` once (VACUUM/WAL), then dropped.
330    let manager = ConnectionManager::<AnyConnection>::new(database_url);
331
332    let setup_pool = Pool::builder()
333        .max_size(1)
334        .connection_customizer(Box::new(AcquireHook { setup: true }))
335        .build(manager)
336        .context("failed to create setup pool")?;
337
338    drop(setup_pool);
339
340    let manager = ConnectionManager::<AnyConnection>::new(database_url);
341
342    let builder = Pool::builder().connection_customizer(Box::new(AcquireHook::default()));
343
344    #[cfg(feature = "sqlite")]
345    let builder = if database_url.starts_with("sqlite://") {
346        builder.max_size(1)
347    } else {
348        builder
349    };
350
351    builder.build(manager).context("failed to create pool")
352}
353
// Tests for `AnyConnection::pre_init`, `close_pool`, and the SQLite dedupe migration.
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;
    use tempfile::tempdir;

    // In-memory SQLite URLs need no directory and must be accepted as-is.
    #[test]
    fn test_pre_init_memory_db_skipped() {
        let result = AnyConnection::pre_init("sqlite://:memory:");
        assert!(result.is_ok());
    }

    // A missing parent directory for the database file is created.
    #[test]
    fn test_pre_init_creates_directory() {
        let tmp = tempdir().unwrap();
        let db_path = tmp.path().join("subdir").join("test.db");
        let url = format!("sqlite://{}", db_path.display());

        let result = AnyConnection::pre_init(&url);
        assert!(result.is_ok());
        assert!(db_path.parent().unwrap().exists());
    }

    // A bare file name has an empty parent, so there is nothing to create.
    #[test]
    fn test_pre_init_no_parent_directory() {
        let result = AnyConnection::pre_init("sqlite://test.db");
        assert!(result.is_ok());
    }

    // An already-existing parent directory is not an error.
    #[test]
    fn test_pre_init_writable_directory_succeeds() {
        let tmp = tempdir().unwrap();
        let subdir = tmp.path().join("writable");
        fs::create_dir(&subdir).unwrap();

        let db_path = subdir.join("test.db");
        let url = format!("sqlite://{}", db_path.display());

        let result = AnyConnection::pre_init(&url);
        assert!(result.is_ok());
    }

    // Non-SQLite URLs are ignored by `pre_init`.
    #[test]
    fn test_pre_init_postgres_skipped() {
        let result = AnyConnection::pre_init("postgres://localhost/test");
        assert!(result.is_ok());
    }

    // After `close_pool` and dropping the pool, no `-wal` / `-shm` files remain.
    #[test]
    #[cfg(feature = "sqlite")]
    fn test_close_pool_cleans_up_wal_files() {
        let tmp = tempdir().unwrap();
        let db_path = tmp.path().join("test.db");
        let url = format!("sqlite://{}", db_path.display());

        AnyConnection::pre_init(&url).unwrap();
        let pool = get_pool(&url).unwrap();

        // Get a connection to trigger WAL mode and create the db
        {
            let mut conn = get_conn(&pool).unwrap();
            conn.migrate().unwrap();
        }

        // WAL files may exist at this point
        close_pool(&pool);
        drop(pool);

        // Verify no WAL files remain
        let wal_path = tmp.path().join("test.db-wal");
        let shm_path = tmp.path().join("test.db-shm");
        assert!(!wal_path.exists(), "WAL file should be cleaned up");
        assert!(!shm_path.exists(), "SHM file should be cleaned up");
    }

    // Seeds a database that already records the listed migrations as applied, with two
    // non-terminal rows for the same path, so `migrate()` runs only the later
    // migration(s) — presumably the dedupe migration under test. It then checks that a
    // single row survives: the newer one, carrying the group's latest `can_process` and
    // the hash from its duplicate.
    #[test]
    #[cfg(feature = "sqlite")]
    fn dedupe_migration_merges_max_can_process_into_survivor() {
        use crate::models::ProcessStatus;
        use crate::schema::scan_events::dsl::{file_path, process_status, scan_events};
        use chrono::{NaiveDate, NaiveDateTime, NaiveTime};
        use diesel::{ExpressionMethods, QueryDsl, RunQueryDsl};

        let tmp = tempdir().unwrap();
        let db_path = tmp.path().join("test.db");
        let url = format!("sqlite://{}", db_path.display());

        AnyConnection::pre_init(&url).unwrap();
        let pool = get_pool(&url).unwrap();
        let mut conn = get_conn(&pool).unwrap();

        // Fixture: schema as of version 20260519000001, plus two duplicate rows.
        conn.batch_execute(
            r#"
            CREATE TABLE scan_events (
                id TEXT PRIMARY KEY NOT NULL,
                event_source TEXT NOT NULL,
                event_timestamp TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL,
                file_path TEXT NOT NULL,
                file_hash TEXT,
                process_status TEXT NOT NULL DEFAULT 'pending',
                found_status TEXT NOT NULL DEFAULT 'not_found',
                failed_times INTEGER DEFAULT 0 NOT NULL,
                next_retry_at TIMESTAMP,
                targets_hit TEXT DEFAULT '' NOT NULL,
                found_at TIMESTAMP,
                processed_at TIMESTAMP,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL,
                updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL,
                can_process TIMESTAMP NOT NULL DEFAULT "2024-10-14T12:00:00.000"
            );

            CREATE TABLE __diesel_schema_migrations (
                version VARCHAR(50) PRIMARY KEY NOT NULL,
                run_on TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
            );

            INSERT INTO __diesel_schema_migrations (version) VALUES
                ('20240829125750'),
                ('20240905143749'),
                ('20240906161345'),
                ('20241012130403'),
                ('20241205114327'),
                ('20241205115656'),
                ('202512300005460000'),
                ('20260519000001');

            INSERT INTO scan_events (
                id, event_source, file_path, file_hash, process_status,
                updated_at, created_at, event_timestamp, can_process
            ) VALUES
                (
                    'older-long-wait', 'sonarr', '/media/migrate.mkv', 'sha256:migrate', 'pending',
                    '2026-01-01 00:00:00', '2026-01-01 00:00:00',
                    '2026-01-01 00:00:00', '2026-01-01 03:00:00'
                ),
                (
                    'newer-short-wait', 'notify', '/media/migrate.mkv', NULL, 'retry',
                    '2026-01-01 01:00:00', '2026-01-01 01:00:00',
                    '2026-01-01 01:00:00', '2026-01-01 02:00:00'
                );
            "#,
        )
        .unwrap();

        conn.migrate().unwrap();

        let pending: String = ProcessStatus::Pending.into();
        let retry: String = ProcessStatus::Retry.into();
        let rows = scan_events
            .filter(file_path.eq("/media/migrate.mkv"))
            .filter(process_status.eq_any([pending, retry]))
            .load::<ScanEvent>(&mut conn)
            .unwrap();

        assert_eq!(rows.len(), 1, "migration should leave one non-terminal row");
        assert_eq!(rows[0].id, "newer-short-wait", "newest row should survive");
        assert_eq!(
            rows[0].can_process,
            NaiveDateTime::new(
                NaiveDate::from_ymd_opt(2026, 1, 1).unwrap(),
                NaiveTime::from_hms_opt(3, 0, 0).unwrap(),
            ),
            "survivor should inherit the duplicate group's longest wait"
        );
        assert_eq!(
            rows[0].file_hash,
            Some("sha256:migrate".to_string()),
            "survivor should inherit a duplicate's hash when it has none"
        );
    }
}