1use crate::models::{NewScanEvent, ScanEvent};
2use anyhow::Context;
3use autopulse_utils::sify;
4use diesel::connection::SimpleConnection;
5use diesel::r2d2::{ConnectionManager, Pool, PooledConnection};
6use diesel::{Connection, RunQueryDsl};
7use diesel::{SaveChangesDsl, SelectableHelper};
8use diesel_migrations::{embed_migrations, EmbeddedMigrations, MigrationHarness};
9use serde::Deserialize;
10use std::fs::OpenOptions;
11use std::path::PathBuf;
12use std::time::{SystemTime, UNIX_EPOCH};
13use tracing::{info, warn};
14
/// Postgres schema migrations, embedded at compile time from `migrations/postgres`
/// and applied by `AnyConnection::migrate`.
#[doc(hidden)]
#[cfg(feature = "postgres")]
const POSTGRES_MIGRATIONS: EmbeddedMigrations = embed_migrations!("migrations/postgres");
18
/// SQLite schema migrations, embedded at compile time from `migrations/sqlite`
/// and applied by `AnyConnection::migrate`.
#[doc(hidden)]
#[cfg(feature = "sqlite")]
const SQLITE_MIGRATIONS: EmbeddedMigrations = embed_migrations!("migrations/sqlite");
22
23#[derive(Deserialize, Debug)]
24#[serde(rename_all = "lowercase")]
25#[derive(Default)]
26pub enum DatabaseType {
27 #[cfg(feature = "sqlite")]
28 #[cfg_attr(feature = "sqlite", default)]
29 Sqlite,
30 #[cfg(feature = "postgres")]
31 #[cfg_attr(not(feature = "sqlite"), default)]
32 Postgres,
33}
34
35impl DatabaseType {
36 pub fn default_url(&self) -> String {
37 match self {
38 #[cfg(feature = "sqlite")]
39 Self::Sqlite => "sqlite://data/autopulse.db".to_string(),
40 #[cfg(feature = "postgres")]
41 Self::Postgres => "postgres://autopulse:autopulse@localhost:5432/autopulse".to_string(),
42 }
43 }
44}
45
/// A connection to one of the enabled database backends.
///
/// The `diesel::MultiConnection` derive lets this enum be used as a single diesel
/// connection type; it is the connection type pooled by `DbPool` / `get_pool`.
#[derive(diesel::MultiConnection)]
pub enum AnyConnection {
    /// PostgreSQL connection (requires the `postgres` feature).
    #[cfg(feature = "postgres")]
    Postgresql(diesel::PgConnection),
    /// SQLite connection (requires the `sqlite` feature).
    #[cfg(feature = "sqlite")]
    Sqlite(diesel::SqliteConnection),
}
83
/// r2d2 connection customizer used by the pools in this module.
///
/// Every connection it sees gets the per-connection configuration. When `setup`
/// is `true`, one-time maintenance runs as well: WAL mode and `VACUUM` for SQLite,
/// and `VACUUM ANALYZE` for Postgres.
#[doc(hidden)]
#[derive(Default, Debug)]
pub struct AcquireHook {
    /// Also run the one-time setup work when a connection is acquired.
    pub setup: bool,
}
89
90impl diesel::r2d2::CustomizeConnection<AnyConnection, diesel::r2d2::Error> for AcquireHook {
91 fn on_acquire(&self, conn: &mut AnyConnection) -> Result<(), diesel::r2d2::Error> {
92 (|| {
93 match conn {
94 #[cfg(feature = "sqlite")]
95 AnyConnection::Sqlite(ref mut conn) => {
96 conn.batch_execute("PRAGMA busy_timeout = 5000")?;
97 conn.batch_execute("PRAGMA synchronous = NORMAL;")?;
98 conn.batch_execute("PRAGMA wal_autocheckpoint = 1000;")?;
99 conn.batch_execute("PRAGMA foreign_keys = ON;")?;
100
101 if self.setup {
102 conn.batch_execute("PRAGMA journal_mode = WAL;")?;
103 conn.batch_execute("VACUUM")?;
104 }
105 }
106 #[cfg(feature = "postgres")]
107 AnyConnection::Postgresql(ref mut conn) => {
108 if self.setup {
109 conn.batch_execute("VACUUM ANALYZE")?;
110 }
111 }
112 }
113 Ok(())
114 })()
115 .map_err(diesel::r2d2::Error::QueryError)
116 }
117}
118
119impl AnyConnection {
120 pub fn pre_init(database_url: &str) -> anyhow::Result<()> {
121 if database_url.starts_with("sqlite://") && !database_url.contains(":memory:") {
122 let path = database_url
123 .strip_prefix("sqlite://")
124 .expect("already checked prefix");
125
126 let path = PathBuf::from(path);
127
128 let Some(parent) = path.parent().filter(|p| !p.as_os_str().is_empty()) else {
129 return Ok(());
130 };
131
132 if !parent.exists() {
134 std::fs::create_dir_all(parent).with_context(|| {
135 format!("failed to create database directory: {}", parent.display())
136 })?;
137 }
138
139 let timestamp = SystemTime::now()
140 .duration_since(UNIX_EPOCH)
141 .map(|duration| duration.as_nanos())
142 .unwrap_or_default();
143 let probe = parent.join(format!(
144 ".autopulse-db-write-test-{}-{timestamp}",
145 std::process::id()
146 ));
147
148 let file = OpenOptions::new()
149 .write(true)
150 .create_new(true)
151 .open(&probe)
152 .with_context(|| {
153 format!("database directory is not writable: {}", parent.display())
154 })?;
155 drop(file);
156
157 std::fs::remove_file(&probe).with_context(|| {
158 format!(
159 "failed to remove database directory write test file: {}",
160 probe.display()
161 )
162 })?;
163 }
164
165 Ok(())
166 }
167
168 pub fn migrate(&mut self) -> anyhow::Result<()> {
169 let migrations_applied = match self {
170 #[cfg(feature = "postgres")]
171 Self::Postgresql(conn) => conn.run_pending_migrations(POSTGRES_MIGRATIONS),
172 #[cfg(feature = "sqlite")]
173 Self::Sqlite(conn) => conn.run_pending_migrations(SQLITE_MIGRATIONS),
174 }
175 .map_err(|e| anyhow::Error::from_boxed(e).context("failed to run migrations"))?;
178
179 if !migrations_applied.is_empty() {
180 info!(
181 "Applied {} migration{}",
182 migrations_applied.len(),
183 sify(&migrations_applied)
184 );
185 }
186
187 Ok(())
188 }
189
190 pub fn close(&mut self) -> anyhow::Result<()> {
191 match self {
192 #[cfg(feature = "postgres")]
193 Self::Postgresql(_) => {}
194 #[cfg(feature = "sqlite")]
195 Self::Sqlite(conn) => {
196 conn.batch_execute("PRAGMA wal_checkpoint(TRUNCATE);")
198 .context("failed to checkpoint WAL")?;
199 }
200 }
201
202 Ok(())
203 }
204
205 pub fn save_changes(&mut self, ev: &mut ScanEvent) -> anyhow::Result<ScanEvent> {
206 let ev = match self {
207 #[cfg(feature = "postgres")]
208 Self::Postgresql(conn) => ev.save_changes::<ScanEvent>(conn),
209 #[cfg(feature = "sqlite")]
212 Self::Sqlite(conn) => ev.save_changes::<ScanEvent>(conn),
213 }?;
214
215 Ok(ev)
216 }
217
218 pub fn insert_and_return(&mut self, ev: &NewScanEvent) -> anyhow::Result<ScanEvent> {
219 match self {
220 #[cfg(feature = "postgres")]
221 Self::Postgresql(conn) => diesel::insert_into(crate::schema::scan_events::table)
222 .values(ev)
223 .returning(ScanEvent::as_returning())
224 .get_result::<ScanEvent>(conn)
225 .map_err(Into::into),
226 #[cfg(feature = "sqlite")]
227 Self::Sqlite(conn) => diesel::insert_into(crate::schema::scan_events::table)
228 .values(ev)
229 .returning(ScanEvent::as_returning())
230 .get_result::<ScanEvent>(conn)
231 .map_err(Into::into),
232 }
233 }
234
235 pub fn upsert_pending(
237 &mut self,
238 ev: &NewScanEvent,
239 now: chrono::NaiveDateTime,
240 ) -> anyhow::Result<ScanEvent> {
241 match self {
242 #[cfg(feature = "postgres")]
243 Self::Postgresql(conn) => upsert_pending_pg(conn, ev, now),
244 #[cfg(feature = "sqlite")]
245 Self::Sqlite(conn) => upsert_pending_sqlite(conn, ev, now),
246 }
247 }
248}
249
/// Postgres implementation of `AnyConnection::upsert_pending`.
///
/// Runs a single `INSERT ... ON CONFLICT (file_path) ... DO UPDATE`. The conflict
/// target is restricted to rows whose `process_status` is Pending or Retry. On
/// conflict, the existing row:
/// - keeps the later of the two `can_process` values,
/// - keeps its `file_hash` unless that is NULL, in which case it takes the new one,
/// - gets `updated_at = now`.
///
/// The inserted or updated row is returned.
///
/// NOTE(review): the conflict-target predicate values are sent as bind parameters.
/// Confirm Postgres still infers the matching partial unique index if a cached
/// prepared statement is planned generically.
#[cfg(feature = "postgres")]
fn upsert_pending_pg(
    conn: &mut diesel::PgConnection,
    ev: &NewScanEvent,
    now: chrono::NaiveDateTime,
) -> anyhow::Result<ScanEvent> {
    use crate::models::ProcessStatus;
    use crate::schema::scan_events::dsl::{
        can_process, file_hash, file_path, process_status, updated_at,
    };
    use diesel::dsl::case_when;
    use diesel::upsert::{excluded, DecoratableTarget};
    use diesel::ExpressionMethods;

    // Only rows still waiting to be processed take part in the merge.
    let mergeable: [String; 2] = [ProcessStatus::Pending.into(), ProcessStatus::Retry.into()];

    let upserted = diesel::insert_into(crate::schema::scan_events::table)
        .values(ev)
        .on_conflict(file_path)
        .filter_target(process_status.eq_any(mergeable))
        .do_update()
        .set((
            updated_at.eq(now),
            // Move `can_process` later, never earlier.
            can_process.eq(
                case_when(can_process.lt(excluded(can_process)), excluded(can_process))
                    .otherwise(can_process),
            ),
            // Fill in a hash only when the stored row has none.
            file_hash.eq(case_when(file_hash.is_null(), excluded(file_hash)).otherwise(file_hash)),
        ))
        .returning(ScanEvent::as_returning())
        .get_result::<ScanEvent>(conn);

    upserted.map_err(Into::into)
}
286
/// SQLite implementation of `AnyConnection::upsert_pending`.
///
/// Looks up a Pending/Retry row with the same `file_path`. If one is found, it is
/// updated in place: `updated_at = now`, the later `can_process` wins, and the
/// existing `file_hash` is kept or, if absent, filled from `ev`. Otherwise `ev` is
/// inserted. The stored row is returned in both cases.
///
/// NOTE(review): the lookup and the write are separate statements with no
/// surrounding transaction. Correctness relies on nothing else writing the same
/// path in between, e.g. the single-connection pool used for `sqlite://` URLs.
/// Confirm this holds for every deployment.
#[cfg(feature = "sqlite")]
fn upsert_pending_sqlite(
    conn: &mut diesel::SqliteConnection,
    ev: &NewScanEvent,
    now: chrono::NaiveDateTime,
) -> anyhow::Result<ScanEvent> {
    use crate::models::ProcessStatus;
    use crate::schema::scan_events::dsl::{
        can_process, file_hash, file_path, process_status, scan_events, updated_at,
    };
    use diesel::{ExpressionMethods, QueryDsl};
    use diesel::{OptionalExtension, SelectableHelper};

    let mergeable: [String; 2] = [ProcessStatus::Pending.into(), ProcessStatus::Retry.into()];

    let current: Option<ScanEvent> = scan_events
        .filter(file_path.eq(&ev.file_path))
        .filter(process_status.eq_any(mergeable))
        .first::<ScanEvent>(conn)
        .optional()?;

    // No mergeable row: plain insert.
    let Some(row) = current else {
        return diesel::insert_into(crate::schema::scan_events::table)
            .values(ev)
            .returning(ScanEvent::as_returning())
            .get_result::<ScanEvent>(conn)
            .map_err(Into::into);
    };

    let merged_can_process = std::cmp::max(row.can_process, ev.can_process);
    let merged_hash = row.file_hash.clone().or_else(|| ev.file_hash.clone());

    diesel::update(&row)
        .set((
            updated_at.eq(now),
            can_process.eq(merged_can_process),
            file_hash.eq(merged_hash),
        ))
        .get_result::<ScanEvent>(conn)
        .map_err(Into::into)
}
330
/// r2d2 pool of [`AnyConnection`]s; the same type that `get_pool` returns.
#[doc(hidden)]
pub type DbPool = Pool<ConnectionManager<AnyConnection>>;
333
334#[doc(hidden)]
335pub fn get_conn(
336 pool: &Pool<ConnectionManager<AnyConnection>>,
337) -> anyhow::Result<PooledConnection<ConnectionManager<AnyConnection>>> {
338 pool.get().context("failed to get connection from pool")
339}
340
341pub fn close_pool(pool: &Pool<ConnectionManager<AnyConnection>>) {
342 match pool.get() {
343 Ok(mut conn) => {
344 if let Err(e) = conn.close() {
345 warn!("failed to close database connection cleanly: {e}");
346 }
347 }
348 Err(e) => {
349 warn!("failed to get connection for pool shutdown: {e}");
350 }
351 }
352}
353
354#[doc(hidden)]
355pub fn get_pool(database_url: &String) -> anyhow::Result<Pool<ConnectionManager<AnyConnection>>> {
356 let manager = ConnectionManager::<AnyConnection>::new(database_url);
358
359 let setup_pool = Pool::builder()
360 .max_size(1)
361 .connection_customizer(Box::new(AcquireHook { setup: true }))
362 .build(manager)
363 .context("failed to create setup pool")?;
364
365 drop(setup_pool);
366
367 let manager = ConnectionManager::<AnyConnection>::new(database_url);
368
369 let builder = Pool::builder().connection_customizer(Box::new(AcquireHook::default()));
370
371 #[cfg(feature = "sqlite")]
372 let builder = if database_url.starts_with("sqlite://") {
373 builder.max_size(1)
374 } else {
375 builder
376 };
377
378 builder.build(manager).context("failed to create pool")
379}
380
381#[cfg(test)]
382mod tests {
383 use super::*;
384 use std::fs;
385 use tempfile::tempdir;
386
387 #[test]
388 fn test_pre_init_memory_db_skipped() {
389 let result = AnyConnection::pre_init("sqlite://:memory:");
390 assert!(result.is_ok());
391 }
392
393 #[test]
394 fn test_pre_init_creates_directory() {
395 let tmp = tempdir().unwrap();
396 let db_path = tmp.path().join("subdir").join("test.db");
397 let url = format!("sqlite://{}", db_path.display());
398
399 let result = AnyConnection::pre_init(&url);
400 assert!(result.is_ok());
401 assert!(db_path.parent().unwrap().exists());
402 }
403
404 #[test]
405 fn test_pre_init_no_parent_directory() {
406 let result = AnyConnection::pre_init("sqlite://test.db");
407 assert!(result.is_ok());
408 }
409
410 #[test]
411 fn test_pre_init_writable_directory_succeeds() {
412 let tmp = tempdir().unwrap();
413 let subdir = tmp.path().join("writable");
414 fs::create_dir(&subdir).unwrap();
415
416 let db_path = subdir.join("test.db");
417 let url = format!("sqlite://{}", db_path.display());
418
419 let result = AnyConnection::pre_init(&url);
420 assert!(result.is_ok());
421 }
422
423 #[cfg(unix)]
424 #[test]
425 fn test_pre_init_existing_unwritable_directory_fails_with_context() {
426 use std::os::unix::fs::PermissionsExt;
427
428 let tmp = tempdir().unwrap();
429 let subdir = tmp.path().join("readonly");
430 fs::create_dir(&subdir).unwrap();
431 fs::set_permissions(&subdir, fs::Permissions::from_mode(0o555)).unwrap();
432
433 let db_path = subdir.join("test.db");
434 let url = format!("sqlite://{}", db_path.display());
435
436 let result = AnyConnection::pre_init(&url);
437
438 fs::set_permissions(&subdir, fs::Permissions::from_mode(0o755)).unwrap();
439 let err = result.expect_err("unwritable database directory should fail pre-init");
440 let err = err.to_string();
441 assert!(err.contains("database directory is not writable"));
442 assert!(err.contains(&subdir.display().to_string()));
443 }
444
445 #[test]
446 fn test_pre_init_postgres_skipped() {
447 let result = AnyConnection::pre_init("postgres://localhost/test");
448 assert!(result.is_ok());
449 }
450
451 #[test]
452 #[cfg(feature = "sqlite")]
453 fn test_close_pool_cleans_up_wal_files() {
454 let tmp = tempdir().unwrap();
455 let db_path = tmp.path().join("test.db");
456 let url = format!("sqlite://{}", db_path.display());
457
458 AnyConnection::pre_init(&url).unwrap();
459 let pool = get_pool(&url).unwrap();
460
461 {
463 let mut conn = get_conn(&pool).unwrap();
464 conn.migrate().unwrap();
465 }
466
467 close_pool(&pool);
469 drop(pool);
470
471 let wal_path = tmp.path().join("test.db-wal");
473 let shm_path = tmp.path().join("test.db-shm");
474 assert!(!wal_path.exists(), "WAL file should be cleaned up");
475 assert!(!shm_path.exists(), "SHM file should be cleaned up");
476 }
477
478 #[test]
479 #[cfg(feature = "sqlite")]
480 fn dedupe_migration_merges_max_can_process_into_survivor() {
481 use crate::models::ProcessStatus;
482 use crate::schema::scan_events::dsl::{file_path, process_status, scan_events};
483 use chrono::{NaiveDate, NaiveDateTime, NaiveTime};
484 use diesel::{ExpressionMethods, QueryDsl, RunQueryDsl};
485
486 let tmp = tempdir().unwrap();
487 let db_path = tmp.path().join("test.db");
488 let url = format!("sqlite://{}", db_path.display());
489
490 AnyConnection::pre_init(&url).unwrap();
491 let pool = get_pool(&url).unwrap();
492 let mut conn = get_conn(&pool).unwrap();
493
494 conn.batch_execute(
495 r#"
496 CREATE TABLE scan_events (
497 id TEXT PRIMARY KEY NOT NULL,
498 event_source TEXT NOT NULL,
499 event_timestamp TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL,
500 file_path TEXT NOT NULL,
501 file_hash TEXT,
502 process_status TEXT NOT NULL DEFAULT 'pending',
503 found_status TEXT NOT NULL DEFAULT 'not_found',
504 failed_times INTEGER DEFAULT 0 NOT NULL,
505 next_retry_at TIMESTAMP,
506 targets_hit TEXT DEFAULT '' NOT NULL,
507 found_at TIMESTAMP,
508 processed_at TIMESTAMP,
509 created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL,
510 updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL,
511 can_process TIMESTAMP NOT NULL DEFAULT "2024-10-14T12:00:00.000"
512 );
513
514 CREATE TABLE __diesel_schema_migrations (
515 version VARCHAR(50) PRIMARY KEY NOT NULL,
516 run_on TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
517 );
518
519 INSERT INTO __diesel_schema_migrations (version) VALUES
520 ('20240829125750'),
521 ('20240905143749'),
522 ('20240906161345'),
523 ('20241012130403'),
524 ('20241205114327'),
525 ('20241205115656'),
526 ('202512300005460000'),
527 ('20260519000001');
528
529 INSERT INTO scan_events (
530 id, event_source, file_path, file_hash, process_status,
531 updated_at, created_at, event_timestamp, can_process
532 ) VALUES
533 (
534 'older-long-wait', 'sonarr', '/media/migrate.mkv', 'sha256:migrate', 'pending',
535 '2026-01-01 00:00:00', '2026-01-01 00:00:00',
536 '2026-01-01 00:00:00', '2026-01-01 03:00:00'
537 ),
538 (
539 'newer-short-wait', 'notify', '/media/migrate.mkv', NULL, 'retry',
540 '2026-01-01 01:00:00', '2026-01-01 01:00:00',
541 '2026-01-01 01:00:00', '2026-01-01 02:00:00'
542 );
543 "#,
544 )
545 .unwrap();
546
547 conn.migrate().unwrap();
548
549 let pending: String = ProcessStatus::Pending.into();
550 let retry: String = ProcessStatus::Retry.into();
551 let rows = scan_events
552 .filter(file_path.eq("/media/migrate.mkv"))
553 .filter(process_status.eq_any([pending, retry]))
554 .load::<ScanEvent>(&mut conn)
555 .unwrap();
556
557 assert_eq!(rows.len(), 1, "migration should leave one non-terminal row");
558 assert_eq!(rows[0].id, "newer-short-wait", "newest row should survive");
559 assert_eq!(
560 rows[0].can_process,
561 NaiveDateTime::new(
562 NaiveDate::from_ymd_opt(2026, 1, 1).unwrap(),
563 NaiveTime::from_hms_opt(3, 0, 0).unwrap(),
564 ),
565 "survivor should inherit the duplicate group's longest wait"
566 );
567 assert_eq!(
568 rows[0].file_hash,
569 Some("sha256:migrate".to_string()),
570 "survivor should inherit a duplicate's hash when it has none"
571 );
572 }
573}