Skip to main content

autopulse_database/
conn.rs

1use crate::models::{NewScanEvent, ScanEvent};
2use anyhow::Context;
3use autopulse_utils::sify;
4use diesel::connection::SimpleConnection;
5use diesel::r2d2::{ConnectionManager, Pool, PooledConnection};
6use diesel::{Connection, RunQueryDsl};
7use diesel::{SaveChangesDsl, SelectableHelper};
8use diesel_migrations::{embed_migrations, EmbeddedMigrations, MigrationHarness};
9use serde::Deserialize;
10use std::fs::OpenOptions;
11use std::path::PathBuf;
12use std::time::{SystemTime, UNIX_EPOCH};
13use tracing::{info, warn};
14
// Migration sets embedded into the crate via `embed_migrations!`, one per
// backend feature; `AnyConnection::migrate` picks the one matching the
// active connection.
#[cfg(feature = "postgres")]
#[doc(hidden)]
const POSTGRES_MIGRATIONS: EmbeddedMigrations = embed_migrations!("migrations/postgres");

#[cfg(feature = "sqlite")]
#[doc(hidden)]
const SQLITE_MIGRATIONS: EmbeddedMigrations = embed_migrations!("migrations/sqlite");
22
23#[derive(Deserialize, Debug)]
24#[serde(rename_all = "lowercase")]
25#[derive(Default)]
26pub enum DatabaseType {
27    #[cfg(feature = "sqlite")]
28    #[cfg_attr(feature = "sqlite", default)]
29    Sqlite,
30    #[cfg(feature = "postgres")]
31    #[cfg_attr(not(feature = "sqlite"), default)]
32    Postgres,
33}
34
35impl DatabaseType {
36    pub fn default_url(&self) -> String {
37        match self {
38            #[cfg(feature = "sqlite")]
39            Self::Sqlite => "sqlite://data/autopulse.db".to_string(),
40            #[cfg(feature = "postgres")]
41            Self::Postgres => "postgres://autopulse:autopulse@localhost:5432/autopulse".to_string(),
42        }
43    }
44}
45
46/// Represents a connection to either a `PostgreSQL` or `SQLite` database.
47#[derive(diesel::MultiConnection)]
48pub enum AnyConnection {
49    /// A connection to a `PostgreSQL` database.
50    ///
51    /// This is used when the `database_url` is a `PostgreSQL` URL.
52    ///
53    /// # Example
54    ///
55    /// ```md
56    /// postgres://user:password@localhost:5432/database
57    /// ```
58    #[cfg(feature = "postgres")]
59    Postgresql(diesel::PgConnection),
60    // Mysql(diesel::MysqlConnection),
61    /// A connection to a `SQLite` database.
62    ///
63    /// This is used when the `database_url` is a `SQLite` URL.
64    ///
65    /// Note: The directory where the database is stored will also be populated with a WAL file and a journal file.
66    ///
67    /// # Example
68    ///
69    /// ```bash
70    /// # Relative path
71    /// sqlite://database.db
72    /// sqlite://data/database.db
73    ///
74    /// # Absolute path
75    /// sqlite:///data/database.db
76    ///
77    /// # In-memory database
78    /// sqlite://:memory: # In-memory database
79    /// ```
80    #[cfg(feature = "sqlite")]
81    Sqlite(diesel::SqliteConnection),
82}
83
/// r2d2 connection customizer run on every connection the pool opens.
///
/// Per-connection settings are always applied; when `setup` is `true` the
/// one-off maintenance statements (SQLite: WAL journal mode + `VACUUM`,
/// Postgres: `VACUUM ANALYZE`) run as well.
#[doc(hidden)]
#[derive(Default, Debug)]
pub struct AcquireHook {
    /// Whether to also run the one-off setup/maintenance statements.
    pub setup: bool,
}
89
90impl diesel::r2d2::CustomizeConnection<AnyConnection, diesel::r2d2::Error> for AcquireHook {
91    fn on_acquire(&self, conn: &mut AnyConnection) -> Result<(), diesel::r2d2::Error> {
92        (|| {
93            match conn {
94                #[cfg(feature = "sqlite")]
95                AnyConnection::Sqlite(ref mut conn) => {
96                    conn.batch_execute("PRAGMA busy_timeout = 5000")?;
97                    conn.batch_execute("PRAGMA synchronous = NORMAL;")?;
98                    conn.batch_execute("PRAGMA wal_autocheckpoint = 1000;")?;
99                    conn.batch_execute("PRAGMA foreign_keys = ON;")?;
100
101                    if self.setup {
102                        conn.batch_execute("PRAGMA journal_mode = WAL;")?;
103                        conn.batch_execute("VACUUM")?;
104                    }
105                }
106                #[cfg(feature = "postgres")]
107                AnyConnection::Postgresql(ref mut conn) => {
108                    if self.setup {
109                        conn.batch_execute("VACUUM ANALYZE")?;
110                    }
111                }
112            }
113            Ok(())
114        })()
115        .map_err(diesel::r2d2::Error::QueryError)
116    }
117}
118
119impl AnyConnection {
120    pub fn pre_init(database_url: &str) -> anyhow::Result<()> {
121        if database_url.starts_with("sqlite://") && !database_url.contains(":memory:") {
122            let path = database_url
123                .strip_prefix("sqlite://")
124                .expect("already checked prefix");
125
126            let path = PathBuf::from(path);
127
128            let Some(parent) = path.parent().filter(|p| !p.as_os_str().is_empty()) else {
129                return Ok(());
130            };
131
132            // Create directory if it doesn't exist
133            if !parent.exists() {
134                std::fs::create_dir_all(parent).with_context(|| {
135                    format!("failed to create database directory: {}", parent.display())
136                })?;
137            }
138
139            let timestamp = SystemTime::now()
140                .duration_since(UNIX_EPOCH)
141                .map(|duration| duration.as_nanos())
142                .unwrap_or_default();
143            let probe = parent.join(format!(
144                ".autopulse-db-write-test-{}-{timestamp}",
145                std::process::id()
146            ));
147
148            let file = OpenOptions::new()
149                .write(true)
150                .create_new(true)
151                .open(&probe)
152                .with_context(|| {
153                    format!("database directory is not writable: {}", parent.display())
154                })?;
155            drop(file);
156
157            std::fs::remove_file(&probe).with_context(|| {
158                format!(
159                    "failed to remove database directory write test file: {}",
160                    probe.display()
161                )
162            })?;
163        }
164
165        Ok(())
166    }
167
168    pub fn migrate(&mut self) -> anyhow::Result<()> {
169        let migrations_applied = match self {
170            #[cfg(feature = "postgres")]
171            Self::Postgresql(conn) => conn.run_pending_migrations(POSTGRES_MIGRATIONS),
172            #[cfg(feature = "sqlite")]
173            Self::Sqlite(conn) => conn.run_pending_migrations(SQLITE_MIGRATIONS),
174        }
175        // Preserve `e.source()` chain via anyhow::Error::from_boxed; the
176        // previous `anyhow!("...{e}")` flattened it to Display only.
177        .map_err(|e| anyhow::Error::from_boxed(e).context("failed to run migrations"))?;
178
179        if !migrations_applied.is_empty() {
180            info!(
181                "Applied {} migration{}",
182                migrations_applied.len(),
183                sify(&migrations_applied)
184            );
185        }
186
187        Ok(())
188    }
189
190    pub fn close(&mut self) -> anyhow::Result<()> {
191        match self {
192            #[cfg(feature = "postgres")]
193            Self::Postgresql(_) => {}
194            #[cfg(feature = "sqlite")]
195            Self::Sqlite(conn) => {
196                // Should cleanup spare wal/shm files
197                conn.batch_execute("PRAGMA wal_checkpoint(TRUNCATE);")
198                    .context("failed to checkpoint WAL")?;
199            }
200        }
201
202        Ok(())
203    }
204
205    pub fn save_changes(&mut self, ev: &mut ScanEvent) -> anyhow::Result<ScanEvent> {
206        let ev = match self {
207            #[cfg(feature = "postgres")]
208            Self::Postgresql(conn) => ev.save_changes::<ScanEvent>(conn),
209            // #[cfg(feature = "mysql")]
210            // AnyConnection::Mysql(conn) => ev.save_changes::<ScanEvent>(conn),
211            #[cfg(feature = "sqlite")]
212            Self::Sqlite(conn) => ev.save_changes::<ScanEvent>(conn),
213        }?;
214
215        Ok(ev)
216    }
217
218    pub fn insert_and_return(&mut self, ev: &NewScanEvent) -> anyhow::Result<ScanEvent> {
219        match self {
220            #[cfg(feature = "postgres")]
221            Self::Postgresql(conn) => diesel::insert_into(crate::schema::scan_events::table)
222                .values(ev)
223                .returning(ScanEvent::as_returning())
224                .get_result::<ScanEvent>(conn)
225                .map_err(Into::into),
226            #[cfg(feature = "sqlite")]
227            Self::Sqlite(conn) => diesel::insert_into(crate::schema::scan_events::table)
228                .values(ev)
229                .returning(ScanEvent::as_returning())
230                .get_result::<ScanEvent>(conn)
231                .map_err(Into::into),
232        }
233    }
234
235    /// Inserts a queued event, or updates the existing pending/retry row for the path.
236    pub fn upsert_pending(
237        &mut self,
238        ev: &NewScanEvent,
239        now: chrono::NaiveDateTime,
240    ) -> anyhow::Result<ScanEvent> {
241        match self {
242            #[cfg(feature = "postgres")]
243            Self::Postgresql(conn) => upsert_pending_pg(conn, ev, now),
244            #[cfg(feature = "sqlite")]
245            Self::Sqlite(conn) => upsert_pending_sqlite(conn, ev, now),
246        }
247    }
248}
249
/// Postgres implementation of [`AnyConnection::upsert_pending`]: a single
/// `INSERT ... ON CONFLICT (file_path) WHERE <pending/retry> DO UPDATE ...
/// RETURNING` statement.
#[cfg(feature = "postgres")]
fn upsert_pending_pg(
    conn: &mut diesel::PgConnection,
    ev: &NewScanEvent,
    now: chrono::NaiveDateTime,
) -> anyhow::Result<ScanEvent> {
    use crate::models::ProcessStatus;
    use crate::schema::scan_events::dsl::{can_process, file_hash, file_path, updated_at};
    use diesel::dsl::{case_when, sql};
    use diesel::sql_types::Bool;
    use diesel::upsert::{excluded, DecoratableTarget};
    use diesel::ExpressionMethods;

    // Keep this predicate aligned with the partial index in the Postgres
    // migrations. Postgres only uses that index as the ON CONFLICT arbiter if it
    // can prove, at *plan* time, that this predicate implies the index predicate.
    // A bound parameter (what `eq_any` renders: `= ANY($1)`) is folded to a
    // constant only in custom plans; once a cached prepared statement switches to
    // a generic plan, inference fails with "there is no unique or exclusion
    // constraint matching the ON CONFLICT specification". Render the statuses as
    // SQL literals instead — they come from our own enum, never from user input.
    let pending: String = ProcessStatus::Pending.into();
    let retry: String = ProcessStatus::Retry.into();
    let index_predicate = format!("process_status IN ('{pending}', '{retry}')");

    diesel::insert_into(crate::schema::scan_events::table)
        .values(ev)
        .on_conflict(file_path)
        .filter_target(sql::<Bool>(&index_predicate))
        .do_update()
        .set((
            updated_at.eq(now),
            // Keep the later of the stored and incoming `can_process` values.
            can_process.eq(
                case_when(can_process.lt(excluded(can_process)), excluded(can_process))
                    .otherwise(can_process),
            ),
            // Only fill in the hash when the stored row has none.
            file_hash.eq(case_when(file_hash.is_null(), excluded(file_hash)).otherwise(file_hash)),
        ))
        .returning(ScanEvent::as_returning())
        .get_result::<ScanEvent>(conn)
        .map_err(Into::into)
}
286
/// SQLite implementation of [`AnyConnection::upsert_pending`]: looks up an open
/// (pending/retry) row for the path, then either merges into it or inserts `ev`.
#[cfg(feature = "sqlite")]
fn upsert_pending_sqlite(
    conn: &mut diesel::SqliteConnection,
    ev: &NewScanEvent,
    now: chrono::NaiveDateTime,
) -> anyhow::Result<ScanEvent> {
    use crate::models::ProcessStatus;
    use crate::schema::scan_events::dsl::{
        can_process, file_hash, file_path, process_status, scan_events, updated_at,
    };
    use diesel::{ExpressionMethods, OptionalExtension, QueryDsl, SelectableHelper};

    // Diesel cannot target SQLite partial indexes in ON CONFLICT, so the upsert
    // is emulated with a lookup followed by an UPDATE or INSERT. No transaction
    // wraps the two steps; the SQLite pool is capped at one connection, which
    // serializes them within this process.
    let open_statuses: [String; 2] = [ProcessStatus::Pending.into(), ProcessStatus::Retry.into()];

    let open_row = scan_events
        .filter(file_path.eq(&ev.file_path))
        .filter(process_status.eq_any(open_statuses))
        .first::<ScanEvent>(conn)
        .optional()?;

    let Some(row) = open_row else {
        return diesel::insert_into(crate::schema::scan_events::table)
            .values(ev)
            .returning(ScanEvent::as_returning())
            .get_result::<ScanEvent>(conn)
            .map_err(Into::into);
    };

    // Same merge rules as the Postgres path: later `can_process` wins and an
    // existing hash is never overwritten.
    let merged_can_process = row.can_process.max(ev.can_process);
    let merged_hash = row.file_hash.as_ref().or(ev.file_hash.as_ref()).cloned();

    diesel::update(&row)
        .set((
            updated_at.eq(now),
            can_process.eq(merged_can_process),
            file_hash.eq(merged_hash),
        ))
        .get_result::<ScanEvent>(conn)
        .map_err(Into::into)
}
330
331#[doc(hidden)]
332pub type DbPool = Pool<ConnectionManager<AnyConnection>>;
333
334#[doc(hidden)]
335pub fn get_conn(
336    pool: &Pool<ConnectionManager<AnyConnection>>,
337) -> anyhow::Result<PooledConnection<ConnectionManager<AnyConnection>>> {
338    pool.get().context("failed to get connection from pool")
339}
340
341pub fn close_pool(pool: &Pool<ConnectionManager<AnyConnection>>) {
342    match pool.get() {
343        Ok(mut conn) => {
344            if let Err(e) = conn.close() {
345                warn!("failed to close database connection cleanly: {e}");
346            }
347        }
348        Err(e) => {
349            warn!("failed to get connection for pool shutdown: {e}");
350        }
351    }
352}
353
354#[doc(hidden)]
355pub fn get_pool(database_url: &String) -> anyhow::Result<Pool<ConnectionManager<AnyConnection>>> {
356    // First pool fires `AcquireHook { setup: true }` once (VACUUM/WAL), then dropped.
357    let manager = ConnectionManager::<AnyConnection>::new(database_url);
358
359    let setup_pool = Pool::builder()
360        .max_size(1)
361        .connection_customizer(Box::new(AcquireHook { setup: true }))
362        .build(manager)
363        .context("failed to create setup pool")?;
364
365    drop(setup_pool);
366
367    let manager = ConnectionManager::<AnyConnection>::new(database_url);
368
369    let builder = Pool::builder().connection_customizer(Box::new(AcquireHook::default()));
370
371    #[cfg(feature = "sqlite")]
372    let builder = if database_url.starts_with("sqlite://") {
373        builder.max_size(1)
374    } else {
375        builder
376    };
377
378    builder.build(manager).context("failed to create pool")
379}
380
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;
    use tempfile::tempdir;

    #[test]
    fn test_pre_init_memory_db_skipped() {
        let result = AnyConnection::pre_init("sqlite://:memory:");
        assert!(result.is_ok());
    }

    #[test]
    fn test_pre_init_creates_directory() {
        let tmp = tempdir().unwrap();
        let db_path = tmp.path().join("subdir").join("test.db");
        let url = format!("sqlite://{}", db_path.display());

        let result = AnyConnection::pre_init(&url);
        assert!(result.is_ok());
        assert!(db_path.parent().unwrap().exists());
    }

    #[test]
    fn test_pre_init_no_parent_directory() {
        let result = AnyConnection::pre_init("sqlite://test.db");
        assert!(result.is_ok());
    }

    #[test]
    fn test_pre_init_writable_directory_succeeds() {
        let tmp = tempdir().unwrap();
        let subdir = tmp.path().join("writable");
        fs::create_dir(&subdir).unwrap();

        let db_path = subdir.join("test.db");
        let url = format!("sqlite://{}", db_path.display());

        let result = AnyConnection::pre_init(&url);
        assert!(result.is_ok());
    }

    #[cfg(unix)]
    #[test]
    fn test_pre_init_existing_unwritable_directory_fails_with_context() {
        use std::os::unix::fs::PermissionsExt;

        let tmp = tempdir().unwrap();
        let subdir = tmp.path().join("readonly");
        fs::create_dir(&subdir).unwrap();
        fs::set_permissions(&subdir, fs::Permissions::from_mode(0o555)).unwrap();

        // Permission bits are not enforced for privileged users (e.g. root in a
        // CI container). If we can still write here the scenario cannot be
        // reproduced, so skip instead of failing spuriously.
        if fs::write(subdir.join(".permission-check"), b"").is_ok() {
            fs::set_permissions(&subdir, fs::Permissions::from_mode(0o755)).unwrap();
            return;
        }

        let db_path = subdir.join("test.db");
        let url = format!("sqlite://{}", db_path.display());

        let result = AnyConnection::pre_init(&url);

        fs::set_permissions(&subdir, fs::Permissions::from_mode(0o755)).unwrap();
        let err = result.expect_err("unwritable database directory should fail pre-init");
        let err = err.to_string();
        assert!(err.contains("database directory is not writable"));
        assert!(err.contains(&subdir.display().to_string()));
    }

    #[test]
    fn test_pre_init_postgres_skipped() {
        let result = AnyConnection::pre_init("postgres://localhost/test");
        assert!(result.is_ok());
    }

    #[test]
    #[cfg(feature = "sqlite")]
    fn test_close_pool_cleans_up_wal_files() {
        let tmp = tempdir().unwrap();
        let db_path = tmp.path().join("test.db");
        let url = format!("sqlite://{}", db_path.display());

        AnyConnection::pre_init(&url).unwrap();
        let pool = get_pool(&url).unwrap();

        // Get a connection to trigger WAL mode and create the db
        {
            let mut conn = get_conn(&pool).unwrap();
            conn.migrate().unwrap();
        }

        // WAL files may exist at this point
        close_pool(&pool);
        drop(pool);

        // Verify no WAL files remain
        let wal_path = tmp.path().join("test.db-wal");
        let shm_path = tmp.path().join("test.db-shm");
        assert!(!wal_path.exists(), "WAL file should be cleaned up");
        assert!(!shm_path.exists(), "SHM file should be cleaned up");
    }

    /// Seeds a database whose migration history is stubbed up to
    /// `20260519000001`, with two non-terminal rows for the same path, then runs
    /// the remaining migrations and checks how the duplicates were merged.
    #[test]
    #[cfg(feature = "sqlite")]
    fn dedupe_migration_merges_max_can_process_into_survivor() {
        use crate::models::ProcessStatus;
        use crate::schema::scan_events::dsl::{file_path, process_status, scan_events};
        use chrono::{NaiveDate, NaiveDateTime, NaiveTime};
        use diesel::{ExpressionMethods, QueryDsl, RunQueryDsl};

        let tmp = tempdir().unwrap();
        let db_path = tmp.path().join("test.db");
        let url = format!("sqlite://{}", db_path.display());

        AnyConnection::pre_init(&url).unwrap();
        let pool = get_pool(&url).unwrap();
        let mut conn = get_conn(&pool).unwrap();

        conn.batch_execute(
            r#"
            CREATE TABLE scan_events (
                id TEXT PRIMARY KEY NOT NULL,
                event_source TEXT NOT NULL,
                event_timestamp TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL,
                file_path TEXT NOT NULL,
                file_hash TEXT,
                process_status TEXT NOT NULL DEFAULT 'pending',
                found_status TEXT NOT NULL DEFAULT 'not_found',
                failed_times INTEGER DEFAULT 0 NOT NULL,
                next_retry_at TIMESTAMP,
                targets_hit TEXT DEFAULT '' NOT NULL,
                found_at TIMESTAMP,
                processed_at TIMESTAMP,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL,
                updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL,
                can_process TIMESTAMP NOT NULL DEFAULT "2024-10-14T12:00:00.000"
            );

            CREATE TABLE __diesel_schema_migrations (
                version VARCHAR(50) PRIMARY KEY NOT NULL,
                run_on TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
            );

            INSERT INTO __diesel_schema_migrations (version) VALUES
                ('20240829125750'),
                ('20240905143749'),
                ('20240906161345'),
                ('20241012130403'),
                ('20241205114327'),
                ('20241205115656'),
                ('202512300005460000'),
                ('20260519000001');

            INSERT INTO scan_events (
                id, event_source, file_path, file_hash, process_status,
                updated_at, created_at, event_timestamp, can_process
            ) VALUES
                (
                    'older-long-wait', 'sonarr', '/media/migrate.mkv', 'sha256:migrate', 'pending',
                    '2026-01-01 00:00:00', '2026-01-01 00:00:00',
                    '2026-01-01 00:00:00', '2026-01-01 03:00:00'
                ),
                (
                    'newer-short-wait', 'notify', '/media/migrate.mkv', NULL, 'retry',
                    '2026-01-01 01:00:00', '2026-01-01 01:00:00',
                    '2026-01-01 01:00:00', '2026-01-01 02:00:00'
                );
            "#,
        )
        .unwrap();

        conn.migrate().unwrap();

        let pending: String = ProcessStatus::Pending.into();
        let retry: String = ProcessStatus::Retry.into();
        let rows = scan_events
            .filter(file_path.eq("/media/migrate.mkv"))
            .filter(process_status.eq_any([pending, retry]))
            .load::<ScanEvent>(&mut conn)
            .unwrap();

        assert_eq!(rows.len(), 1, "migration should leave one non-terminal row");
        assert_eq!(rows[0].id, "newer-short-wait", "newest row should survive");
        assert_eq!(
            rows[0].can_process,
            NaiveDateTime::new(
                NaiveDate::from_ymd_opt(2026, 1, 1).unwrap(),
                NaiveTime::from_hms_opt(3, 0, 0).unwrap(),
            ),
            "survivor should inherit the duplicate group's longest wait"
        );
        assert_eq!(
            rows[0].file_hash,
            Some("sha256:migrate".to_string()),
            "survivor should inherit a duplicate's hash when it has none"
        );
    }
}