1use crate::models::{NewScanEvent, ScanEvent};
2use anyhow::Context;
3use autopulse_utils::sify;
4use diesel::connection::SimpleConnection;
5use diesel::r2d2::{ConnectionManager, Pool, PooledConnection};
6use diesel::{Connection, RunQueryDsl};
7use diesel::{SaveChangesDsl, SelectableHelper};
8use diesel_migrations::{embed_migrations, EmbeddedMigrations, MigrationHarness};
9use serde::Deserialize;
10use std::path::PathBuf;
11use tracing::{info, warn};
12
/// Postgres schema migrations, embedded into the binary by `embed_migrations!`.
#[cfg(feature = "postgres")]
#[doc(hidden)]
const POSTGRES_MIGRATIONS: EmbeddedMigrations = embed_migrations!("migrations/postgres");

/// SQLite schema migrations, embedded into the binary by `embed_migrations!`.
#[cfg(feature = "sqlite")]
#[doc(hidden)]
const SQLITE_MIGRATIONS: EmbeddedMigrations = embed_migrations!("migrations/sqlite");
20
21#[derive(Deserialize, Debug)]
22#[serde(rename_all = "lowercase")]
23#[derive(Default)]
24pub enum DatabaseType {
25 #[cfg(feature = "sqlite")]
26 #[cfg_attr(feature = "sqlite", default)]
27 Sqlite,
28 #[cfg(feature = "postgres")]
29 #[cfg_attr(not(feature = "sqlite"), default)]
30 Postgres,
31}
32
33impl DatabaseType {
34 pub fn default_url(&self) -> String {
35 match self {
36 #[cfg(feature = "sqlite")]
37 Self::Sqlite => "sqlite://data/autopulse.db".to_string(),
38 #[cfg(feature = "postgres")]
39 Self::Postgres => "postgres://autopulse:autopulse@localhost:5432/autopulse".to_string(),
40 }
41 }
42}
43
/// A connection to one of the compiled-in database backends.
///
/// `#[derive(diesel::MultiConnection)]` lets this enum act as a single diesel connection type;
/// it is the connection type managed by the r2d2 pools built in this module.
// NOTE(review): diesel's MultiConnection derive is understood to try the variants in
// declaration order when establishing a connection — confirm before reordering them.
#[derive(diesel::MultiConnection)]
pub enum AnyConnection {
    /// PostgreSQL connection (only present with the `postgres` feature).
    #[cfg(feature = "postgres")]
    Postgresql(diesel::PgConnection),
    /// SQLite connection (only present with the `sqlite` feature).
    #[cfg(feature = "sqlite")]
    Sqlite(diesel::SqliteConnection),
}
81
/// r2d2 connection customizer that configures every connection a pool creates.
#[doc(hidden)]
#[derive(Default, Debug)]
pub struct AcquireHook {
    /// When `true`, one-time maintenance runs on top of the per-connection settings
    /// (SQLite: WAL journal mode and `VACUUM`; Postgres: `VACUUM ANALYZE`).
    pub setup: bool,
}
87
88impl diesel::r2d2::CustomizeConnection<AnyConnection, diesel::r2d2::Error> for AcquireHook {
89 fn on_acquire(&self, conn: &mut AnyConnection) -> Result<(), diesel::r2d2::Error> {
90 (|| {
91 match conn {
92 #[cfg(feature = "sqlite")]
93 AnyConnection::Sqlite(ref mut conn) => {
94 conn.batch_execute("PRAGMA busy_timeout = 5000")?;
95 conn.batch_execute("PRAGMA synchronous = NORMAL;")?;
96 conn.batch_execute("PRAGMA wal_autocheckpoint = 1000;")?;
97 conn.batch_execute("PRAGMA foreign_keys = ON;")?;
98
99 if self.setup {
100 conn.batch_execute("PRAGMA journal_mode = WAL;")?;
101 conn.batch_execute("VACUUM")?;
102 }
103 }
104 #[cfg(feature = "postgres")]
105 AnyConnection::Postgresql(ref mut conn) => {
106 if self.setup {
107 conn.batch_execute("VACUUM ANALYZE")?;
108 }
109 }
110 }
111 Ok(())
112 })()
113 .map_err(diesel::r2d2::Error::QueryError)
114 }
115}
116
117impl AnyConnection {
118 pub fn pre_init(database_url: &str) -> anyhow::Result<()> {
119 if database_url.starts_with("sqlite://") && !database_url.contains(":memory:") {
120 let path = database_url
121 .strip_prefix("sqlite://")
122 .expect("already checked prefix");
123
124 let path = PathBuf::from(path);
125
126 let Some(parent) = path.parent().filter(|p| !p.as_os_str().is_empty()) else {
127 return Ok(());
128 };
129
130 if !parent.exists() {
132 std::fs::create_dir_all(parent).with_context(|| {
133 format!("failed to create database directory: {}", parent.display())
134 })?;
135 }
136 }
137
138 Ok(())
139 }
140
141 pub fn migrate(&mut self) -> anyhow::Result<()> {
142 let migrations_applied = match self {
143 #[cfg(feature = "postgres")]
144 Self::Postgresql(conn) => conn.run_pending_migrations(POSTGRES_MIGRATIONS),
145 #[cfg(feature = "sqlite")]
146 Self::Sqlite(conn) => conn.run_pending_migrations(SQLITE_MIGRATIONS),
147 }
148 .map_err(|e| anyhow::Error::from_boxed(e).context("failed to run migrations"))?;
151
152 if !migrations_applied.is_empty() {
153 info!(
154 "Applied {} migration{}",
155 migrations_applied.len(),
156 sify(&migrations_applied)
157 );
158 }
159
160 Ok(())
161 }
162
163 pub fn close(&mut self) -> anyhow::Result<()> {
164 match self {
165 #[cfg(feature = "postgres")]
166 Self::Postgresql(_) => {}
167 #[cfg(feature = "sqlite")]
168 Self::Sqlite(conn) => {
169 conn.batch_execute("PRAGMA wal_checkpoint(TRUNCATE);")
171 .context("failed to checkpoint WAL")?;
172 }
173 }
174
175 Ok(())
176 }
177
178 pub fn save_changes(&mut self, ev: &mut ScanEvent) -> anyhow::Result<ScanEvent> {
179 let ev = match self {
180 #[cfg(feature = "postgres")]
181 Self::Postgresql(conn) => ev.save_changes::<ScanEvent>(conn),
182 #[cfg(feature = "sqlite")]
185 Self::Sqlite(conn) => ev.save_changes::<ScanEvent>(conn),
186 }?;
187
188 Ok(ev)
189 }
190
191 pub fn insert_and_return(&mut self, ev: &NewScanEvent) -> anyhow::Result<ScanEvent> {
192 match self {
193 #[cfg(feature = "postgres")]
194 Self::Postgresql(conn) => diesel::insert_into(crate::schema::scan_events::table)
195 .values(ev)
196 .returning(ScanEvent::as_returning())
197 .get_result::<ScanEvent>(conn)
198 .map_err(Into::into),
199 #[cfg(feature = "sqlite")]
200 Self::Sqlite(conn) => diesel::insert_into(crate::schema::scan_events::table)
201 .values(ev)
202 .returning(ScanEvent::as_returning())
203 .get_result::<ScanEvent>(conn)
204 .map_err(Into::into),
205 }
206 }
207
208 pub fn upsert_pending(
210 &mut self,
211 ev: &NewScanEvent,
212 now: chrono::NaiveDateTime,
213 ) -> anyhow::Result<ScanEvent> {
214 match self {
215 #[cfg(feature = "postgres")]
216 Self::Postgresql(conn) => upsert_pending_pg(conn, ev, now),
217 #[cfg(feature = "sqlite")]
218 Self::Sqlite(conn) => upsert_pending_sqlite(conn, ev, now),
219 }
220 }
221}
222
/// Postgres implementation of [`AnyConnection::upsert_pending`], done as a single
/// `INSERT ... ON CONFLICT (file_path) ... DO UPDATE` whose conflict target is restricted to
/// rows whose `process_status` is pending or retry.
///
/// On conflict the existing row is updated in place:
/// - `updated_at` is set to `now`;
/// - `can_process` becomes the later of the stored and incoming values;
/// - `file_hash` is taken from `ev` only when the stored value is `NULL`.
///
/// The inserted or updated row is returned.
// NOTE(review): the conflict target presumably relies on a partial unique index on `file_path`
// limited to pending/retry rows (defined in the migrations, not visible here). The predicate is
// sent as a bind parameter — confirm Postgres can still infer that index under a generic plan.
#[cfg(feature = "postgres")]
fn upsert_pending_pg(
    conn: &mut diesel::PgConnection,
    ev: &NewScanEvent,
    now: chrono::NaiveDateTime,
) -> anyhow::Result<ScanEvent> {
    use crate::models::ProcessStatus;
    use crate::schema::scan_events::dsl::{
        can_process, file_hash, file_path, process_status, updated_at,
    };
    use diesel::dsl::case_when;
    use diesel::upsert::{excluded, DecoratableTarget};
    use diesel::ExpressionMethods;

    let mergeable_statuses: [String; 2] =
        [ProcessStatus::Pending.into(), ProcessStatus::Retry.into()];

    // The later of the stored and incoming `can_process`.
    let merged_can_process = case_when(can_process.lt(excluded(can_process)), excluded(can_process))
        .otherwise(can_process);
    // Keep a stored hash; only fill it from the incoming row when it is missing.
    let merged_file_hash =
        case_when(file_hash.is_null(), excluded(file_hash)).otherwise(file_hash);

    let row = diesel::insert_into(crate::schema::scan_events::table)
        .values(ev)
        .on_conflict(file_path)
        .filter_target(process_status.eq_any(mergeable_statuses))
        .do_update()
        .set((
            updated_at.eq(now),
            can_process.eq(merged_can_process),
            file_hash.eq(merged_file_hash),
        ))
        .returning(ScanEvent::as_returning())
        .get_result::<ScanEvent>(conn)?;

    Ok(row)
}
259
/// SQLite implementation of [`AnyConnection::upsert_pending`].
///
/// Looks up a non-terminal (`pending` or `retry`) row with the same `file_path`.
/// - If one exists, it is updated in place: `updated_at` is set to `now`, `can_process` becomes
///   the later of the stored and incoming values, and a missing `file_hash` is filled from `ev`.
/// - Otherwise `ev` is inserted as a new row.
///
/// Either way the resulting row is returned.
///
/// The lookup and the write run inside one transaction, so the write acts on the state that was
/// read. This mirrors the single-statement `INSERT ... ON CONFLICT` used for Postgres.
/// `Connection::transaction` becomes a savepoint if the caller already opened a transaction.
///
/// # Errors
/// Propagates any database error. On error the transaction is rolled back.
#[cfg(feature = "sqlite")]
fn upsert_pending_sqlite(
    conn: &mut diesel::SqliteConnection,
    ev: &NewScanEvent,
    now: chrono::NaiveDateTime,
) -> anyhow::Result<ScanEvent> {
    use crate::models::ProcessStatus;
    use crate::schema::scan_events::dsl::{
        can_process, file_hash, file_path, process_status, scan_events, updated_at,
    };
    use diesel::{ExpressionMethods, QueryDsl};
    use diesel::{OptionalExtension, SelectableHelper};

    let pending: String = ProcessStatus::Pending.into();
    let retry: String = ProcessStatus::Retry.into();

    conn.transaction::<ScanEvent, anyhow::Error, _>(|conn| {
        // Read and return rows through `ScanEvent`'s `Selectable` impl, as the insert path
        // already does. Columns are then mapped by name instead of by table position.
        let existing: Option<ScanEvent> = scan_events
            .filter(file_path.eq(&ev.file_path))
            .filter(process_status.eq_any([pending, retry]))
            .select(ScanEvent::as_select())
            .first::<ScanEvent>(conn)
            .optional()?;

        let Some(existing) = existing else {
            return diesel::insert_into(crate::schema::scan_events::table)
                .values(ev)
                .returning(ScanEvent::as_returning())
                .get_result::<ScanEvent>(conn)
                .map_err(Into::into);
        };

        let later_can_process = std::cmp::max(existing.can_process, ev.can_process);
        let file_hash_value = existing.file_hash.clone().or_else(|| ev.file_hash.clone());

        diesel::update(&existing)
            .set((
                updated_at.eq(now),
                can_process.eq(later_can_process),
                file_hash.eq(file_hash_value),
            ))
            .returning(ScanEvent::as_returning())
            .get_result::<ScanEvent>(conn)
            .map_err(Into::into)
    })
}
303
/// r2d2 pool of [`AnyConnection`]s; the same type that `get_pool` returns.
#[doc(hidden)]
pub type DbPool = Pool<ConnectionManager<AnyConnection>>;
306
307#[doc(hidden)]
308pub fn get_conn(
309 pool: &Pool<ConnectionManager<AnyConnection>>,
310) -> anyhow::Result<PooledConnection<ConnectionManager<AnyConnection>>> {
311 pool.get().context("failed to get connection from pool")
312}
313
314pub fn close_pool(pool: &Pool<ConnectionManager<AnyConnection>>) {
315 match pool.get() {
316 Ok(mut conn) => {
317 if let Err(e) = conn.close() {
318 warn!("failed to close database connection cleanly: {e}");
319 }
320 }
321 Err(e) => {
322 warn!("failed to get connection for pool shutdown: {e}");
323 }
324 }
325}
326
327#[doc(hidden)]
328pub fn get_pool(database_url: &String) -> anyhow::Result<Pool<ConnectionManager<AnyConnection>>> {
329 let manager = ConnectionManager::<AnyConnection>::new(database_url);
331
332 let setup_pool = Pool::builder()
333 .max_size(1)
334 .connection_customizer(Box::new(AcquireHook { setup: true }))
335 .build(manager)
336 .context("failed to create setup pool")?;
337
338 drop(setup_pool);
339
340 let manager = ConnectionManager::<AnyConnection>::new(database_url);
341
342 let builder = Pool::builder().connection_customizer(Box::new(AcquireHook::default()));
343
344 #[cfg(feature = "sqlite")]
345 let builder = if database_url.starts_with("sqlite://") {
346 builder.max_size(1)
347 } else {
348 builder
349 };
350
351 builder.build(manager).context("failed to create pool")
352}
353
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;
    use tempfile::tempdir;

    // In-memory SQLite URLs need no directory and must be accepted as-is.
    #[test]
    fn test_pre_init_memory_db_skipped() {
        let result = AnyConnection::pre_init("sqlite://:memory:");
        assert!(result.is_ok());
    }

    // A missing parent directory of a file-backed SQLite database is created.
    #[test]
    fn test_pre_init_creates_directory() {
        let tmp = tempdir().unwrap();
        let db_path = tmp.path().join("subdir").join("test.db");
        let url = format!("sqlite://{}", db_path.display());

        let result = AnyConnection::pre_init(&url);
        assert!(result.is_ok());
        assert!(db_path.parent().unwrap().exists());
    }

    // A bare file name has no directory component, so there is nothing to create.
    #[test]
    fn test_pre_init_no_parent_directory() {
        let result = AnyConnection::pre_init("sqlite://test.db");
        assert!(result.is_ok());
    }

    // An already existing directory is accepted without error.
    #[test]
    fn test_pre_init_writable_directory_succeeds() {
        let tmp = tempdir().unwrap();
        let subdir = tmp.path().join("writable");
        fs::create_dir(&subdir).unwrap();

        let db_path = subdir.join("test.db");
        let url = format!("sqlite://{}", db_path.display());

        let result = AnyConnection::pre_init(&url);
        assert!(result.is_ok());
    }

    // URLs without the `sqlite://` prefix are ignored by `pre_init`.
    #[test]
    fn test_pre_init_postgres_skipped() {
        let result = AnyConnection::pre_init("postgres://localhost/test");
        assert!(result.is_ok());
    }

    // After `close_pool` and dropping the pool, no `-wal`/`-shm` side files may remain.
    #[test]
    #[cfg(feature = "sqlite")]
    fn test_close_pool_cleans_up_wal_files() {
        let tmp = tempdir().unwrap();
        let db_path = tmp.path().join("test.db");
        let url = format!("sqlite://{}", db_path.display());

        AnyConnection::pre_init(&url).unwrap();
        let pool = get_pool(&url).unwrap();

        {
            // Scoped so the connection returns to the pool before `close_pool` checks one out
            // again; this `sqlite://` pool holds a single connection.
            let mut conn = get_conn(&pool).unwrap();
            conn.migrate().unwrap();
        }

        close_pool(&pool);
        // The side files are only expected to be gone once the pool, and with it the last
        // connection, has been dropped.
        drop(pool);

        let wal_path = tmp.path().join("test.db-wal");
        let shm_path = tmp.path().join("test.db-shm");
        assert!(!wal_path.exists(), "WAL file should be cleaned up");
        assert!(!shm_path.exists(), "SHM file should be cleaned up");
    }

    // Duplicate non-terminal rows for one file are merged by the migrations run in `migrate()`.
    // The newest row survives and inherits the group's latest `can_process` and a missing hash.
    #[test]
    #[cfg(feature = "sqlite")]
    fn dedupe_migration_merges_max_can_process_into_survivor() {
        use crate::models::ProcessStatus;
        use crate::schema::scan_events::dsl::{file_path, process_status, scan_events};
        use chrono::{NaiveDate, NaiveDateTime, NaiveTime};
        use diesel::{ExpressionMethods, QueryDsl, RunQueryDsl};

        let tmp = tempdir().unwrap();
        let db_path = tmp.path().join("test.db");
        let url = format!("sqlite://{}", db_path.display());

        AnyConnection::pre_init(&url).unwrap();
        let pool = get_pool(&url).unwrap();
        let mut conn = get_conn(&pool).unwrap();

        // Hand-build an older schema state:
        // - the `scan_events` table;
        // - a migrations table that marks the listed versions as applied, so `migrate()` only
        //   runs the remaining (presumably dedupe) migrations;
        // - two non-terminal rows for the same file. The older row has the later `can_process`
        //   (03:00) and a hash; the newer row has neither.
        conn.batch_execute(
            r#"
            CREATE TABLE scan_events (
                id TEXT PRIMARY KEY NOT NULL,
                event_source TEXT NOT NULL,
                event_timestamp TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL,
                file_path TEXT NOT NULL,
                file_hash TEXT,
                process_status TEXT NOT NULL DEFAULT 'pending',
                found_status TEXT NOT NULL DEFAULT 'not_found',
                failed_times INTEGER DEFAULT 0 NOT NULL,
                next_retry_at TIMESTAMP,
                targets_hit TEXT DEFAULT '' NOT NULL,
                found_at TIMESTAMP,
                processed_at TIMESTAMP,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL,
                updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL,
                can_process TIMESTAMP NOT NULL DEFAULT "2024-10-14T12:00:00.000"
            );

            CREATE TABLE __diesel_schema_migrations (
                version VARCHAR(50) PRIMARY KEY NOT NULL,
                run_on TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
            );

            INSERT INTO __diesel_schema_migrations (version) VALUES
                ('20240829125750'),
                ('20240905143749'),
                ('20240906161345'),
                ('20241012130403'),
                ('20241205114327'),
                ('20241205115656'),
                ('202512300005460000'),
                ('20260519000001');

            INSERT INTO scan_events (
                id, event_source, file_path, file_hash, process_status,
                updated_at, created_at, event_timestamp, can_process
            ) VALUES
                (
                    'older-long-wait', 'sonarr', '/media/migrate.mkv', 'sha256:migrate', 'pending',
                    '2026-01-01 00:00:00', '2026-01-01 00:00:00',
                    '2026-01-01 00:00:00', '2026-01-01 03:00:00'
                ),
                (
                    'newer-short-wait', 'notify', '/media/migrate.mkv', NULL, 'retry',
                    '2026-01-01 01:00:00', '2026-01-01 01:00:00',
                    '2026-01-01 01:00:00', '2026-01-01 02:00:00'
                );
            "#,
        )
        .unwrap();

        conn.migrate().unwrap();

        let pending: String = ProcessStatus::Pending.into();
        let retry: String = ProcessStatus::Retry.into();
        let rows = scan_events
            .filter(file_path.eq("/media/migrate.mkv"))
            .filter(process_status.eq_any([pending, retry]))
            .load::<ScanEvent>(&mut conn)
            .unwrap();

        assert_eq!(rows.len(), 1, "migration should leave one non-terminal row");
        assert_eq!(rows[0].id, "newer-short-wait", "newest row should survive");
        assert_eq!(
            rows[0].can_process,
            NaiveDateTime::new(
                NaiveDate::from_ymd_opt(2026, 1, 1).unwrap(),
                NaiveTime::from_hms_opt(3, 0, 0).unwrap(),
            ),
            "survivor should inherit the duplicate group's longest wait"
        );
        assert_eq!(
            rows[0].file_hash,
            Some("sha256:migrate".to_string()),
            "survivor should inherit a duplicate's hash when it has none"
        );
    }
}