autopulse_database/
conn.rs

1use crate::models::{NewScanEvent, ScanEvent};
2use anyhow::Context;
3use autopulse_utils::sify;
4use diesel::connection::SimpleConnection;
5use diesel::r2d2::{ConnectionManager, Pool, PooledConnection};
6use diesel::{Connection, RunQueryDsl};
7use diesel::{SaveChangesDsl, SelectableHelper};
8use diesel_migrations::{embed_migrations, EmbeddedMigrations, MigrationHarness};
9use serde::Deserialize;
10use std::path::PathBuf;
11use tracing::info;
12
// Migration sets embedded at compile time with `embed_migrations!`, one per
// enabled backend; `AnyConnection::migrate` picks the one matching the connection.
#[cfg(feature = "postgres")]
#[doc(hidden)]
const POSTGRES_MIGRATIONS: EmbeddedMigrations = embed_migrations!("migrations/postgres");

#[cfg(feature = "sqlite")]
#[doc(hidden)]
const SQLITE_MIGRATIONS: EmbeddedMigrations = embed_migrations!("migrations/sqlite");
20
21#[derive(Deserialize, Debug)]
22#[serde(rename_all = "lowercase")]
23#[derive(Default)]
24pub enum DatabaseType {
25    #[cfg(feature = "sqlite")]
26    #[cfg_attr(feature = "sqlite", default)]
27    Sqlite,
28    #[cfg(feature = "postgres")]
29    #[cfg_attr(not(feature = "sqlite"), default)]
30    Postgres,
31}
32
33impl DatabaseType {
34    pub fn default_url(&self) -> String {
35        match self {
36            #[cfg(feature = "sqlite")]
37            Self::Sqlite => "sqlite://data/autopulse.db".to_string(),
38            #[cfg(feature = "postgres")]
39            Self::Postgres => "postgres://autopulse:autopulse@localhost:5432/autopulse".to_string(),
40        }
41    }
42}
43
44/// Represents a connection to either a `PostgreSQL` or `SQLite` database.
45#[derive(diesel::MultiConnection)]
46pub enum AnyConnection {
47    /// A connection to a `PostgreSQL` database.
48    ///
49    /// This is used when the `database_url` is a `PostgreSQL` URL.
50    ///
51    /// # Example
52    ///
53    /// ```md
54    /// postgres://user:password@localhost:5432/database
55    /// ```
56    #[cfg(feature = "postgres")]
57    Postgresql(diesel::PgConnection),
58    // Mysql(diesel::MysqlConnection),
59    /// A connection to a `SQLite` database.
60    ///
61    /// This is used when the `database_url` is a `SQLite` URL.
62    ///
63    /// Note: The directory where the database is stored will also be populated with a WAL file and a journal file.
64    ///
65    /// # Example
66    ///
67    /// ```bash
68    /// # Relative path
69    /// sqlite://database.db
70    /// sqlite://data/database.db
71    ///
72    /// # Absolute path
73    /// sqlite:///data/database.db
74    ///
75    /// # In-memory database
76    /// sqlite://:memory: # In-memory database
77    /// ```
78    #[cfg(feature = "sqlite")]
79    Sqlite(diesel::SqliteConnection),
80}
81
/// Connection customizer installed on the r2d2 pool.
///
/// Every acquired connection receives the per-connection settings; when
/// `setup` is `true` the one-time initialisation statements run as well.
/// `Default` leaves `setup` as `false`.
#[doc(hidden)]
#[derive(Default, Debug)]
pub struct AcquireHook {
    /// Also run the one-time setup statements on acquire.
    pub setup: bool,
}
87
88impl diesel::r2d2::CustomizeConnection<AnyConnection, diesel::r2d2::Error> for AcquireHook {
89    fn on_acquire(&self, conn: &mut AnyConnection) -> Result<(), diesel::r2d2::Error> {
90        (|| {
91            match conn {
92                #[cfg(feature = "sqlite")]
93                AnyConnection::Sqlite(ref mut conn) => {
94                    conn.batch_execute("PRAGMA busy_timeout = 5000")?;
95                    conn.batch_execute("PRAGMA synchronous = NORMAL;")?;
96                    conn.batch_execute("PRAGMA wal_autocheckpoint = 1000;")?;
97                    conn.batch_execute("PRAGMA foreign_keys = ON;")?;
98
99                    if self.setup {
100                        conn.batch_execute("PRAGMA journal_mode = WAL;")?;
101                        conn.batch_execute("VACUUM")?;
102                    }
103                }
104                #[cfg(feature = "postgres")]
105                AnyConnection::Postgresql(ref mut conn) => {
106                    if self.setup {
107                        conn.batch_execute("VACUUM ANALYZE")?;
108                    }
109                }
110            }
111            Ok(())
112        })()
113        .map_err(diesel::r2d2::Error::QueryError)
114    }
115}
116
117impl AnyConnection {
118    pub fn pre_init(database_url: &str) -> anyhow::Result<()> {
119        if database_url.starts_with("sqlite://") && !database_url.contains(":memory:") {
120            let path = database_url.split("sqlite://").collect::<Vec<&str>>()[1];
121            let path = PathBuf::from(path);
122            let parent = path.parent().unwrap();
123
124            if !std::path::Path::new(&path).exists() {
125                std::fs::create_dir_all(parent).with_context(|| {
126                    format!("falsed to create database directory: {}", parent.display())
127                })?;
128            }
129
130            #[cfg(unix)]
131            if path.file_name().map(|x| x.to_str()) != Some(path.to_str()) {
132                use std::os::unix::fs::PermissionsExt;
133
134                std::fs::set_permissions(parent, std::fs::Permissions::from_mode(0o777))
135                    .with_context(|| {
136                        format!(
137                            "falsed to set permissions on database directory: {}",
138                            parent.display()
139                        )
140                    })?;
141            }
142        }
143
144        Ok(())
145    }
146
147    pub fn migrate(&mut self) -> anyhow::Result<()> {
148        let migrations_applied = match self {
149            #[cfg(feature = "postgres")]
150            Self::Postgresql(conn) => conn.run_pending_migrations(POSTGRES_MIGRATIONS),
151            #[cfg(feature = "sqlite")]
152            Self::Sqlite(conn) => conn.run_pending_migrations(SQLITE_MIGRATIONS),
153        }
154        .expect("Could not run migrations");
155
156        if !migrations_applied.is_empty() {
157            info!(
158                "Applied {} migration{}",
159                migrations_applied.len(),
160                sify(&migrations_applied)
161            );
162        }
163
164        Ok(())
165    }
166
167    pub fn save_changes(&mut self, ev: &mut ScanEvent) -> anyhow::Result<ScanEvent> {
168        let ev = match self {
169            #[cfg(feature = "postgres")]
170            Self::Postgresql(conn) => ev.save_changes::<ScanEvent>(conn),
171            // #[cfg(feature = "mysql")]
172            // AnyConnection::Mysql(conn) => ev.save_changes::<ScanEvent>(conn),
173            #[cfg(feature = "sqlite")]
174            Self::Sqlite(conn) => ev.save_changes::<ScanEvent>(conn),
175        }?;
176
177        Ok(ev)
178    }
179
180    pub fn insert_and_return(&mut self, ev: &NewScanEvent) -> anyhow::Result<ScanEvent> {
181        match self {
182            #[cfg(feature = "postgres")]
183            Self::Postgresql(conn) => diesel::insert_into(crate::schema::scan_events::table)
184                .values(ev)
185                .returning(ScanEvent::as_returning())
186                .get_result::<ScanEvent>(conn)
187                .map_err(Into::into),
188            #[cfg(feature = "sqlite")]
189            Self::Sqlite(conn) => diesel::insert_into(crate::schema::scan_events::table)
190                .values(ev)
191                .returning(ScanEvent::as_returning())
192                .get_result::<ScanEvent>(conn)
193                .map_err(Into::into),
194        }
195    }
196}
197
198#[doc(hidden)]
199pub type DbPool = Pool<ConnectionManager<AnyConnection>>;
200
201#[doc(hidden)]
202pub fn get_conn(
203    pool: &Pool<ConnectionManager<AnyConnection>>,
204) -> anyhow::Result<PooledConnection<ConnectionManager<AnyConnection>>> {
205    pool.get().context("failed to get connection from pool")
206}
207
208#[doc(hidden)]
209pub fn get_pool(database_url: &String) -> anyhow::Result<Pool<ConnectionManager<AnyConnection>>> {
210    let manager = ConnectionManager::<AnyConnection>::new(database_url);
211
212    let pool = Pool::builder()
213        .max_size(1)
214        .connection_customizer(Box::new(AcquireHook { setup: true }))
215        .build(manager)
216        .context("failed to create pool");
217
218    drop(pool);
219
220    let manager = ConnectionManager::<AnyConnection>::new(database_url);
221
222    Pool::builder()
223        .connection_customizer(Box::new(AcquireHook::default()))
224        .build(manager)
225        .context("failed to create pool")
226}