autopulse_database/
conn.rs1use crate::models::{NewScanEvent, ScanEvent};
2use anyhow::Context;
3use autopulse_utils::sify;
4use diesel::connection::SimpleConnection;
5use diesel::r2d2::{ConnectionManager, Pool, PooledConnection};
6use diesel::{Connection, RunQueryDsl};
7use diesel::{SaveChangesDsl, SelectableHelper};
8use diesel_migrations::{embed_migrations, EmbeddedMigrations, MigrationHarness};
9use serde::Deserialize;
10use std::path::PathBuf;
11use tracing::info;
12
/// PostgreSQL schema migrations, embedded from `migrations/postgres`.
#[cfg(feature = "postgres")]
#[doc(hidden)]
const POSTGRES_MIGRATIONS: EmbeddedMigrations = embed_migrations!("migrations/postgres");

/// SQLite schema migrations, embedded from `migrations/sqlite`.
#[cfg(feature = "sqlite")]
#[doc(hidden)]
const SQLITE_MIGRATIONS: EmbeddedMigrations = embed_migrations!("migrations/sqlite");
20
21#[derive(Deserialize, Debug)]
22#[serde(rename_all = "lowercase")]
23#[derive(Default)]
24pub enum DatabaseType {
25 #[cfg(feature = "sqlite")]
26 #[cfg_attr(feature = "sqlite", default)]
27 Sqlite,
28 #[cfg(feature = "postgres")]
29 #[cfg_attr(not(feature = "sqlite"), default)]
30 Postgres,
31}
32
33impl DatabaseType {
34 pub fn default_url(&self) -> String {
35 match self {
36 #[cfg(feature = "sqlite")]
37 Self::Sqlite => "sqlite://data/autopulse.db".to_string(),
38 #[cfg(feature = "postgres")]
39 Self::Postgres => "postgres://autopulse:autopulse@localhost:5432/autopulse".to_string(),
40 }
41 }
42}
43
44#[derive(diesel::MultiConnection)]
46pub enum AnyConnection {
47 #[cfg(feature = "postgres")]
57 Postgresql(diesel::PgConnection),
58 #[cfg(feature = "sqlite")]
79 Sqlite(diesel::SqliteConnection),
80}
81
/// r2d2 connection customizer that configures each connection as it is acquired.
///
/// The `Default` value (`setup: false`) applies only per-connection settings.
#[derive(Debug, Default)]
#[doc(hidden)]
pub struct AcquireHook {
    /// When `true`, the one-off setup statements are also executed on acquire.
    pub setup: bool,
}
87
88impl diesel::r2d2::CustomizeConnection<AnyConnection, diesel::r2d2::Error> for AcquireHook {
89 fn on_acquire(&self, conn: &mut AnyConnection) -> Result<(), diesel::r2d2::Error> {
90 (|| {
91 match conn {
92 #[cfg(feature = "sqlite")]
93 AnyConnection::Sqlite(ref mut conn) => {
94 conn.batch_execute("PRAGMA busy_timeout = 5000")?;
95 conn.batch_execute("PRAGMA synchronous = NORMAL;")?;
96 conn.batch_execute("PRAGMA wal_autocheckpoint = 1000;")?;
97 conn.batch_execute("PRAGMA foreign_keys = ON;")?;
98
99 if self.setup {
100 conn.batch_execute("PRAGMA journal_mode = WAL;")?;
101 conn.batch_execute("VACUUM")?;
102 }
103 }
104 #[cfg(feature = "postgres")]
105 AnyConnection::Postgresql(ref mut conn) => {
106 if self.setup {
107 conn.batch_execute("VACUUM ANALYZE")?;
108 }
109 }
110 }
111 Ok(())
112 })()
113 .map_err(diesel::r2d2::Error::QueryError)
114 }
115}
116
117impl AnyConnection {
118 pub fn pre_init(database_url: &str) -> anyhow::Result<()> {
119 if database_url.starts_with("sqlite://") && !database_url.contains(":memory:") {
120 let path = database_url
121 .strip_prefix("sqlite://")
122 .expect("already checked prefix");
123
124 let path = PathBuf::from(path);
125
126 let Some(parent) = path.parent().filter(|p| !p.as_os_str().is_empty()) else {
127 return Ok(());
128 };
129
130 if !parent.exists() {
132 std::fs::create_dir_all(parent).with_context(|| {
133 format!("failed to create database directory: {}", parent.display())
134 })?;
135 }
136 }
137
138 Ok(())
139 }
140
141 pub fn migrate(&mut self) -> anyhow::Result<()> {
142 let migrations_applied = match self {
143 #[cfg(feature = "postgres")]
144 Self::Postgresql(conn) => conn.run_pending_migrations(POSTGRES_MIGRATIONS),
145 #[cfg(feature = "sqlite")]
146 Self::Sqlite(conn) => conn.run_pending_migrations(SQLITE_MIGRATIONS),
147 }
148 .expect("Could not run migrations");
149
150 if !migrations_applied.is_empty() {
151 info!(
152 "Applied {} migration{}",
153 migrations_applied.len(),
154 sify(&migrations_applied)
155 );
156 }
157
158 Ok(())
159 }
160
161 pub fn close(&mut self) -> anyhow::Result<()> {
162 match self {
163 #[cfg(feature = "postgres")]
164 Self::Postgresql(_) => {}
165 #[cfg(feature = "sqlite")]
166 Self::Sqlite(conn) => {
167 conn.batch_execute("PRAGMA wal_checkpoint(TRUNCATE);")
169 .context("failed to checkpoint WAL")?;
170 }
171 }
172
173 Ok(())
174 }
175
176 pub fn save_changes(&mut self, ev: &mut ScanEvent) -> anyhow::Result<ScanEvent> {
177 let ev = match self {
178 #[cfg(feature = "postgres")]
179 Self::Postgresql(conn) => ev.save_changes::<ScanEvent>(conn),
180 #[cfg(feature = "sqlite")]
183 Self::Sqlite(conn) => ev.save_changes::<ScanEvent>(conn),
184 }?;
185
186 Ok(ev)
187 }
188
189 pub fn insert_and_return(&mut self, ev: &NewScanEvent) -> anyhow::Result<ScanEvent> {
190 match self {
191 #[cfg(feature = "postgres")]
192 Self::Postgresql(conn) => diesel::insert_into(crate::schema::scan_events::table)
193 .values(ev)
194 .returning(ScanEvent::as_returning())
195 .get_result::<ScanEvent>(conn)
196 .map_err(Into::into),
197 #[cfg(feature = "sqlite")]
198 Self::Sqlite(conn) => diesel::insert_into(crate::schema::scan_events::table)
199 .values(ev)
200 .returning(ScanEvent::as_returning())
201 .get_result::<ScanEvent>(conn)
202 .map_err(Into::into),
203 }
204 }
205}
206
207#[doc(hidden)]
208pub type DbPool = Pool<ConnectionManager<AnyConnection>>;
209
210#[doc(hidden)]
211pub fn get_conn(
212 pool: &Pool<ConnectionManager<AnyConnection>>,
213) -> anyhow::Result<PooledConnection<ConnectionManager<AnyConnection>>> {
214 pool.get().context("failed to get connection from pool")
215}
216
217pub fn close_pool(pool: &Pool<ConnectionManager<AnyConnection>>) {
218 if let Ok(mut conn) = pool.get() {
219 let _ = conn.close();
220 }
221}
222
223#[doc(hidden)]
224pub fn get_pool(database_url: &String) -> anyhow::Result<Pool<ConnectionManager<AnyConnection>>> {
225 let manager = ConnectionManager::<AnyConnection>::new(database_url);
226
227 let pool = Pool::builder()
228 .max_size(1)
229 .connection_customizer(Box::new(AcquireHook { setup: true }))
230 .build(manager)
231 .context("failed to create pool");
232
233 drop(pool);
234
235 let manager = ConnectionManager::<AnyConnection>::new(database_url);
236
237 Pool::builder()
238 .connection_customizer(Box::new(AcquireHook::default()))
239 .build(manager)
240 .context("failed to create pool")
241}
242
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;
    use tempfile::tempdir;

    /// Formats `path` as a `sqlite://` database URL.
    fn sqlite_url(path: &std::path::Path) -> String {
        format!("sqlite://{}", path.display())
    }

    #[test]
    fn test_pre_init_memory_db_skipped() {
        assert!(AnyConnection::pre_init("sqlite://:memory:").is_ok());
    }

    #[test]
    fn test_pre_init_creates_directory() {
        let tmp = tempdir().unwrap();
        let db_path = tmp.path().join("subdir").join("test.db");

        assert!(AnyConnection::pre_init(&sqlite_url(&db_path)).is_ok());
        assert!(db_path.parent().unwrap().exists());
    }

    #[test]
    fn test_pre_init_no_parent_directory() {
        assert!(AnyConnection::pre_init("sqlite://test.db").is_ok());
    }

    #[test]
    fn test_pre_init_writable_directory_succeeds() {
        let tmp = tempdir().unwrap();
        let writable = tmp.path().join("writable");
        fs::create_dir(&writable).unwrap();

        let url = sqlite_url(&writable.join("test.db"));
        assert!(AnyConnection::pre_init(&url).is_ok());
    }

    #[test]
    fn test_pre_init_postgres_skipped() {
        assert!(AnyConnection::pre_init("postgres://localhost/test").is_ok());
    }

    #[test]
    #[cfg(feature = "sqlite")]
    fn test_close_pool_cleans_up_wal_files() {
        let tmp = tempdir().unwrap();
        let url = sqlite_url(&tmp.path().join("test.db"));

        AnyConnection::pre_init(&url).unwrap();
        let pool = get_pool(&url).unwrap();

        // Scope the checkout so the connection is returned to the pool before shutdown.
        {
            let mut conn = get_conn(&pool).unwrap();
            conn.migrate().unwrap();
        }

        close_pool(&pool);
        // Dropping the pool closes every connection; SQLite is expected to remove
        // the WAL/SHM side files once the last connection closes.
        drop(pool);

        let wal_path = tmp.path().join("test.db-wal");
        let shm_path = tmp.path().join("test.db-shm");
        assert!(!wal_path.exists(), "WAL file should be cleaned up");
        assert!(!shm_path.exists(), "SHM file should be cleaned up");
    }
}