autopulse_database/
conn.rs

1use crate::models::{NewScanEvent, ScanEvent};
2use anyhow::Context;
3use autopulse_utils::sify;
4use diesel::connection::SimpleConnection;
5use diesel::r2d2::{ConnectionManager, Pool, PooledConnection};
6use diesel::{Connection, QueryResult, RunQueryDsl};
7use diesel::{SaveChangesDsl, SelectableHelper};
8use diesel_migrations::{embed_migrations, EmbeddedMigrations, MigrationHarness};
9use std::os::unix::fs::PermissionsExt;
10use std::path::PathBuf;
11use tracing::info;
12
// Migrations embedded at compile time from `migrations/postgres`; only built
// when the `postgres` feature is enabled.
#[cfg(feature = "postgres")]
#[doc(hidden)]
const POSTGRES_MIGRATIONS: EmbeddedMigrations = embed_migrations!("migrations/postgres");
16
// Migrations embedded at compile time from `migrations/sqlite`; only built
// when the `sqlite` feature is enabled.
#[cfg(feature = "sqlite")]
#[doc(hidden)]
const SQLITE_MIGRATIONS: EmbeddedMigrations = embed_migrations!("migrations/sqlite");
20
21/// Represents a connection to either a `PostgreSQL` or `SQLite` database.
22#[derive(diesel::MultiConnection)]
23pub enum AnyConnection {
24    /// A connection to a `PostgreSQL` database.
25    ///
26    /// This is used when the `database_url` is a `PostgreSQL` URL.
27    ///
28    /// # Example
29    ///
30    /// ```md
31    /// postgres://user:password@localhost:5432/database
32    /// ```
33    #[cfg(feature = "postgres")]
34    Postgresql(diesel::PgConnection),
35    // Mysql(diesel::MysqlConnection),
36    /// A connection to a `SQLite` database.
37    ///
38    /// This is used when the `database_url` is a `SQLite` URL.
39    ///
40    /// Note: The directory where the database is stored will also be populated with a WAL file and a journal file.
41    ///
42    /// # Example
43    ///
44    /// ```bash
45    /// # Relative path
46    /// sqlite://database.db
47    /// sqlite://data/database.db
48    ///
49    /// # Absolute path
50    /// sqlite:///data/database.db
51    ///
52    /// # In-memory database
53    /// sqlite://:memory: # In-memory database
54    /// ```
55    #[cfg(feature = "sqlite")]
56    Sqlite(diesel::SqliteConnection),
57}
58
/// Connection customizer installed on the r2d2 pool.
///
/// Its `on_acquire` implementation runs per-connection statements; the
/// `setup` flag additionally enables the one-off maintenance statements.
#[doc(hidden)]
#[derive(Debug, Default)]
pub struct AcquireHook {
    /// When `true`, `on_acquire` also issues the setup-only statements
    /// (WAL journal mode + `VACUUM` for `SQLite`, `VACUUM ANALYZE` for
    /// `PostgreSQL`). Defaults to `false`.
    pub setup: bool,
}
64
65impl diesel::r2d2::CustomizeConnection<AnyConnection, diesel::r2d2::Error> for AcquireHook {
66    fn on_acquire(&self, conn: &mut AnyConnection) -> Result<(), diesel::r2d2::Error> {
67        (|| {
68            match conn {
69                #[cfg(feature = "sqlite")]
70                AnyConnection::Sqlite(ref mut conn) => {
71                    conn.batch_execute("PRAGMA busy_timeout = 5000")?;
72                    conn.batch_execute("PRAGMA synchronous = NORMAL;")?;
73                    conn.batch_execute("PRAGMA wal_autocheckpoint = 1000;")?;
74                    conn.batch_execute("PRAGMA foreign_keys = ON;")?;
75
76                    if self.setup {
77                        conn.batch_execute("PRAGMA journal_mode = WAL;")?;
78                        conn.batch_execute("VACUUM")?;
79                    }
80                }
81                #[cfg(feature = "postgres")]
82                AnyConnection::Postgresql(ref mut conn) => {
83                    if self.setup {
84                        conn.batch_execute("VACUUM ANALYZE")?;
85                    }
86                }
87            }
88            Ok(())
89        })()
90        .map_err(diesel::r2d2::Error::QueryError)
91    }
92}
93
94impl AnyConnection {
95    pub fn pre_init(database_url: &str) -> anyhow::Result<()> {
96        if database_url.starts_with("sqlite://") && !database_url.contains(":memory:") {
97            let path = database_url.split("sqlite://").collect::<Vec<&str>>()[1];
98            let path = PathBuf::from(path);
99            let parent = path.parent().unwrap();
100
101            if !std::path::Path::new(&path).exists() {
102                std::fs::create_dir_all(parent).with_context(|| {
103                    format!("Failed to create database directory: {}", parent.display())
104                })?;
105            }
106
107            if path.file_name().map(|x| x.to_str()) != Some(path.to_str()) {
108                std::fs::set_permissions(parent, std::fs::Permissions::from_mode(0o777))
109                    .with_context(|| {
110                        format!(
111                            "Failed to set permissions on database directory: {}",
112                            parent.display()
113                        )
114                    })?;
115            }
116        }
117
118        Ok(())
119    }
120
121    pub fn migrate(&mut self) -> anyhow::Result<()> {
122        let migrations_applied = match self {
123            #[cfg(feature = "postgres")]
124            Self::Postgresql(conn) => conn.run_pending_migrations(POSTGRES_MIGRATIONS),
125            #[cfg(feature = "sqlite")]
126            Self::Sqlite(conn) => conn.run_pending_migrations(SQLITE_MIGRATIONS),
127        }
128        .expect("Could not run migrations");
129
130        if !migrations_applied.is_empty() {
131            info!(
132                "Applied {} migration{}",
133                migrations_applied.len(),
134                sify(&migrations_applied)
135            );
136        }
137
138        Ok(())
139    }
140
141    pub fn save_changes(&mut self, ev: &mut ScanEvent) -> anyhow::Result<ScanEvent> {
142        let ev = match self {
143            #[cfg(feature = "postgres")]
144            Self::Postgresql(conn) => ev.save_changes::<ScanEvent>(conn),
145            // #[cfg(feature = "mysql")]
146            // AnyConnection::Mysql(conn) => ev.save_changes::<ScanEvent>(conn),
147            #[cfg(feature = "sqlite")]
148            Self::Sqlite(conn) => ev.save_changes::<ScanEvent>(conn),
149        }?;
150
151        Ok(ev)
152    }
153
154    pub fn insert_and_return(&mut self, ev: &NewScanEvent) -> anyhow::Result<ScanEvent> {
155        match self {
156            #[cfg(feature = "postgres")]
157            Self::Postgresql(conn) => diesel::insert_into(crate::schema::scan_events::table)
158                .values(ev)
159                .returning(ScanEvent::as_returning())
160                .get_result::<ScanEvent>(conn)
161                .map_err(Into::into),
162            #[cfg(feature = "sqlite")]
163            Self::Sqlite(conn) => diesel::insert_into(crate::schema::scan_events::table)
164                .values(ev)
165                .returning(ScanEvent::as_returning())
166                .get_result::<ScanEvent>(conn)
167                .map_err(Into::into),
168        }
169    }
170}
171
172#[doc(hidden)]
173pub type DbPool = Pool<ConnectionManager<AnyConnection>>;
174
175#[doc(hidden)]
176pub fn get_conn(
177    pool: &Pool<ConnectionManager<AnyConnection>>,
178) -> anyhow::Result<PooledConnection<ConnectionManager<AnyConnection>>> {
179    pool.get().context("Failed to get connection from pool")
180}
181
182#[doc(hidden)]
183pub fn get_pool(database_url: &String) -> anyhow::Result<Pool<ConnectionManager<AnyConnection>>> {
184    let manager = ConnectionManager::<AnyConnection>::new(database_url);
185
186    let pool = Pool::builder()
187        .max_size(1)
188        .connection_customizer(Box::new(AcquireHook { setup: true }))
189        .build(manager)
190        .context("Failed to create pool");
191
192    drop(pool);
193
194    let manager = ConnectionManager::<AnyConnection>::new(database_url);
195
196    Pool::builder()
197        .connection_customizer(Box::new(AcquireHook::default()))
198        .build(manager)
199        .context("Failed to create pool")
200}