autopulse_database/
conn.rs

1use crate::models::{NewScanEvent, ScanEvent};
2use anyhow::Context;
3use autopulse_utils::sify;
4use diesel::connection::SimpleConnection;
5use diesel::r2d2::{ConnectionManager, Pool, PooledConnection};
6use diesel::{Connection, RunQueryDsl};
7use diesel::{SaveChangesDsl, SelectableHelper};
8use diesel_migrations::{embed_migrations, EmbeddedMigrations, MigrationHarness};
9use serde::Deserialize;
10use std::path::PathBuf;
11use tracing::info;
12
/// `PostgreSQL` migrations embedded into the binary at compile time from the
/// `migrations/postgres` directory. Applied by `AnyConnection::migrate`.
#[doc(hidden)]
#[cfg(feature = "postgres")]
const POSTGRES_MIGRATIONS: EmbeddedMigrations = embed_migrations!("migrations/postgres");
16
/// `SQLite` migrations embedded into the binary at compile time from the
/// `migrations/sqlite` directory. Applied by `AnyConnection::migrate`.
#[doc(hidden)]
#[cfg(feature = "sqlite")]
const SQLITE_MIGRATIONS: EmbeddedMigrations = embed_migrations!("migrations/sqlite");
20
/// The database backend to use.
///
/// Each variant only exists when the matching cargo feature is enabled.
/// When deserialized, variants are spelled in lowercase (`sqlite`, `postgres`).
#[derive(Deserialize, Debug)]
#[serde(rename_all = "lowercase")]
pub enum DatabaseType {
    /// `SQLite` backend; available with the `sqlite` feature.
    #[cfg(feature = "sqlite")]
    Sqlite,
    /// `PostgreSQL` backend; available with the `postgres` feature.
    #[cfg(feature = "postgres")]
    Postgres,
}
29
impl Default for DatabaseType {
    /// Returns the default backend: `SQLite` whenever the `sqlite` feature is
    /// compiled in, otherwise `PostgreSQL` (when only `postgres` is enabled).
    fn default() -> Self {
        // Exactly one of the two cfg-gated blocks below survives compilation
        // and becomes the function's tail expression.
        #[cfg(feature = "sqlite")]
        {
            Self::Sqlite
        }
        // Only used as the fallback when `sqlite` is not compiled in.
        #[cfg(all(not(feature = "sqlite"), feature = "postgres"))]
        {
            Self::Postgres
        }
    }
}
42
43impl DatabaseType {
44    pub fn default_url(&self) -> String {
45        match self {
46            #[cfg(feature = "sqlite")]
47            Self::Sqlite => "sqlite://data/autopulse.db".to_string(),
48            #[cfg(feature = "postgres")]
49            Self::Postgres => "postgres://autopulse:autopulse@localhost:5432/autopulse".to_string(),
50        }
51    }
52}
53
/// Represents a connection to either a `PostgreSQL` or a `SQLite` database.
///
/// Each variant is only compiled in when the matching cargo feature is enabled.
#[derive(diesel::MultiConnection)]
pub enum AnyConnection {
    /// A connection to a `PostgreSQL` database.
    ///
    /// This is used when the `database_url` is a `PostgreSQL` URL.
    ///
    /// # Example
    ///
    /// ```md
    /// postgres://user:password@localhost:5432/database
    /// ```
    #[cfg(feature = "postgres")]
    Postgresql(diesel::PgConnection),
    // MySQL is not wired up; the variant is kept here for reference only:
    // Mysql(diesel::MysqlConnection),
    /// A connection to a `SQLite` database.
    ///
    /// This is used when the `database_url` is a `SQLite` URL.
    ///
    /// Note: the database is switched to WAL journal mode during setup, so
    /// `SQLite` also keeps its write-ahead-log sidecar files (`-wal` and `-shm`)
    /// in the same directory as the database file.
    ///
    /// # Example
    ///
    /// ```bash
    /// # Relative path
    /// sqlite://database.db
    /// sqlite://data/database.db
    ///
    /// # Absolute path
    /// sqlite:///data/database.db
    ///
    /// # In-memory database
    /// sqlite://:memory:
    /// ```
    #[cfg(feature = "sqlite")]
    Sqlite(diesel::SqliteConnection),
}
91
/// Pool connection customizer that configures every connection as it is
/// acquired (see its `CustomizeConnection::on_acquire` implementation).
#[doc(hidden)]
#[derive(Debug, Default)]
pub struct AcquireHook {
    /// When `true`, `on_acquire` additionally runs the one-time setup
    /// statements: for `SQLite`, switching to WAL journal mode and `VACUUM`;
    /// for `PostgreSQL`, `VACUUM ANALYZE`. Defaults to `false`.
    pub setup: bool,
}
97
98impl diesel::r2d2::CustomizeConnection<AnyConnection, diesel::r2d2::Error> for AcquireHook {
99    fn on_acquire(&self, conn: &mut AnyConnection) -> Result<(), diesel::r2d2::Error> {
100        (|| {
101            match conn {
102                #[cfg(feature = "sqlite")]
103                AnyConnection::Sqlite(ref mut conn) => {
104                    conn.batch_execute("PRAGMA busy_timeout = 5000")?;
105                    conn.batch_execute("PRAGMA synchronous = NORMAL;")?;
106                    conn.batch_execute("PRAGMA wal_autocheckpoint = 1000;")?;
107                    conn.batch_execute("PRAGMA foreign_keys = ON;")?;
108
109                    if self.setup {
110                        conn.batch_execute("PRAGMA journal_mode = WAL;")?;
111                        conn.batch_execute("VACUUM")?;
112                    }
113                }
114                #[cfg(feature = "postgres")]
115                AnyConnection::Postgresql(ref mut conn) => {
116                    if self.setup {
117                        conn.batch_execute("VACUUM ANALYZE")?;
118                    }
119                }
120            }
121            Ok(())
122        })()
123        .map_err(diesel::r2d2::Error::QueryError)
124    }
125}
126
127impl AnyConnection {
128    pub fn pre_init(database_url: &str) -> anyhow::Result<()> {
129        if database_url.starts_with("sqlite://") && !database_url.contains(":memory:") {
130            let path = database_url.split("sqlite://").collect::<Vec<&str>>()[1];
131            let path = PathBuf::from(path);
132            let parent = path.parent().unwrap();
133
134            if !std::path::Path::new(&path).exists() {
135                std::fs::create_dir_all(parent).with_context(|| {
136                    format!("Failed to create database directory: {}", parent.display())
137                })?;
138            }
139
140            #[cfg(unix)]
141            if path.file_name().map(|x| x.to_str()) != Some(path.to_str()) {
142                use std::os::unix::fs::PermissionsExt;
143
144                std::fs::set_permissions(parent, std::fs::Permissions::from_mode(0o777))
145                    .with_context(|| {
146                        format!(
147                            "Failed to set permissions on database directory: {}",
148                            parent.display()
149                        )
150                    })?;
151            }
152        }
153
154        Ok(())
155    }
156
157    pub fn migrate(&mut self) -> anyhow::Result<()> {
158        let migrations_applied = match self {
159            #[cfg(feature = "postgres")]
160            Self::Postgresql(conn) => conn.run_pending_migrations(POSTGRES_MIGRATIONS),
161            #[cfg(feature = "sqlite")]
162            Self::Sqlite(conn) => conn.run_pending_migrations(SQLITE_MIGRATIONS),
163        }
164        .expect("Could not run migrations");
165
166        if !migrations_applied.is_empty() {
167            info!(
168                "Applied {} migration{}",
169                migrations_applied.len(),
170                sify(&migrations_applied)
171            );
172        }
173
174        Ok(())
175    }
176
177    pub fn save_changes(&mut self, ev: &mut ScanEvent) -> anyhow::Result<ScanEvent> {
178        let ev = match self {
179            #[cfg(feature = "postgres")]
180            Self::Postgresql(conn) => ev.save_changes::<ScanEvent>(conn),
181            // #[cfg(feature = "mysql")]
182            // AnyConnection::Mysql(conn) => ev.save_changes::<ScanEvent>(conn),
183            #[cfg(feature = "sqlite")]
184            Self::Sqlite(conn) => ev.save_changes::<ScanEvent>(conn),
185        }?;
186
187        Ok(ev)
188    }
189
190    pub fn insert_and_return(&mut self, ev: &NewScanEvent) -> anyhow::Result<ScanEvent> {
191        match self {
192            #[cfg(feature = "postgres")]
193            Self::Postgresql(conn) => diesel::insert_into(crate::schema::scan_events::table)
194                .values(ev)
195                .returning(ScanEvent::as_returning())
196                .get_result::<ScanEvent>(conn)
197                .map_err(Into::into),
198            #[cfg(feature = "sqlite")]
199            Self::Sqlite(conn) => diesel::insert_into(crate::schema::scan_events::table)
200                .values(ev)
201                .returning(ScanEvent::as_returning())
202                .get_result::<ScanEvent>(conn)
203                .map_err(Into::into),
204        }
205    }
206}
207
/// Shorthand for an r2d2 connection pool of [`AnyConnection`]s.
#[doc(hidden)]
pub type DbPool = Pool<ConnectionManager<AnyConnection>>;
210
211#[doc(hidden)]
212pub fn get_conn(
213    pool: &Pool<ConnectionManager<AnyConnection>>,
214) -> anyhow::Result<PooledConnection<ConnectionManager<AnyConnection>>> {
215    pool.get().context("Failed to get connection from pool")
216}
217
218#[doc(hidden)]
219pub fn get_pool(database_url: &String) -> anyhow::Result<Pool<ConnectionManager<AnyConnection>>> {
220    let manager = ConnectionManager::<AnyConnection>::new(database_url);
221
222    let pool = Pool::builder()
223        .max_size(1)
224        .connection_customizer(Box::new(AcquireHook { setup: true }))
225        .build(manager)
226        .context("Failed to create pool");
227
228    drop(pool);
229
230    let manager = ConnectionManager::<AnyConnection>::new(database_url);
231
232    Pool::builder()
233        .connection_customizer(Box::new(AcquireHook::default()))
234        .build(manager)
235        .context("Failed to create pool")
236}