autopulse_database/
conn.rs1use crate::models::{NewScanEvent, ScanEvent};
2use anyhow::Context;
3use autopulse_utils::sify;
4use diesel::connection::SimpleConnection;
5use diesel::r2d2::{ConnectionManager, Pool, PooledConnection};
6use diesel::{Connection, RunQueryDsl};
7use diesel::{SaveChangesDsl, SelectableHelper};
8use diesel_migrations::{embed_migrations, EmbeddedMigrations, MigrationHarness};
9use serde::Deserialize;
10use std::path::PathBuf;
11use tracing::info;
12
/// Postgres schema migrations, embedded into the binary at compile time from
/// `migrations/postgres`.
#[cfg(feature = "postgres")]
#[doc(hidden)]
const POSTGRES_MIGRATIONS: EmbeddedMigrations = embed_migrations!("migrations/postgres");
16
/// SQLite schema migrations, embedded into the binary at compile time from
/// `migrations/sqlite`.
#[cfg(feature = "sqlite")]
#[doc(hidden)]
const SQLITE_MIGRATIONS: EmbeddedMigrations = embed_migrations!("migrations/sqlite");
20
21#[derive(Deserialize, Debug)]
22#[serde(rename_all = "lowercase")]
23pub enum DatabaseType {
24 #[cfg(feature = "sqlite")]
25 Sqlite,
26 #[cfg(feature = "postgres")]
27 Postgres,
28}
29
30impl Default for DatabaseType {
31 fn default() -> Self {
32 #[cfg(feature = "sqlite")]
33 {
34 Self::Sqlite
35 }
36 #[cfg(all(not(feature = "sqlite"), feature = "postgres"))]
37 {
38 Self::Postgres
39 }
40 }
41}
42
43impl DatabaseType {
44 pub fn default_url(&self) -> String {
45 match self {
46 #[cfg(feature = "sqlite")]
47 Self::Sqlite => "sqlite://data/autopulse.db".to_string(),
48 #[cfg(feature = "postgres")]
49 Self::Postgres => "postgres://autopulse:autopulse@localhost:5432/autopulse".to_string(),
50 }
51 }
52}
53
/// A diesel connection to one of the compiled-in backends.
///
/// Each variant exists only when its Cargo feature (`postgres` / `sqlite`) is enabled.
/// The `diesel::MultiConnection` derive is what lets this enum be used as a single
/// connection type, e.g. as the `ConnectionManager` target for the pool.
// NOTE(review): the derive presumably tries variants in declaration order when
// establishing a connection from a URL. Keep this order unless that is verified.
#[derive(diesel::MultiConnection)]
pub enum AnyConnection {
    #[cfg(feature = "postgres")]
    Postgresql(diesel::PgConnection),
    #[cfg(feature = "sqlite")]
    Sqlite(diesel::SqliteConnection),
}
91
/// r2d2 connection customizer that configures each connection as it is acquired.
#[doc(hidden)]
#[derive(Default, Debug)]
pub struct AcquireHook {
    /// When `true`, one-off maintenance statements are also run on acquire (WAL mode +
    /// `VACUUM` for SQLite, `VACUUM ANALYZE` for Postgres). Defaults to `false`.
    pub setup: bool,
}
97
98impl diesel::r2d2::CustomizeConnection<AnyConnection, diesel::r2d2::Error> for AcquireHook {
99 fn on_acquire(&self, conn: &mut AnyConnection) -> Result<(), diesel::r2d2::Error> {
100 (|| {
101 match conn {
102 #[cfg(feature = "sqlite")]
103 AnyConnection::Sqlite(ref mut conn) => {
104 conn.batch_execute("PRAGMA busy_timeout = 5000")?;
105 conn.batch_execute("PRAGMA synchronous = NORMAL;")?;
106 conn.batch_execute("PRAGMA wal_autocheckpoint = 1000;")?;
107 conn.batch_execute("PRAGMA foreign_keys = ON;")?;
108
109 if self.setup {
110 conn.batch_execute("PRAGMA journal_mode = WAL;")?;
111 conn.batch_execute("VACUUM")?;
112 }
113 }
114 #[cfg(feature = "postgres")]
115 AnyConnection::Postgresql(ref mut conn) => {
116 if self.setup {
117 conn.batch_execute("VACUUM ANALYZE")?;
118 }
119 }
120 }
121 Ok(())
122 })()
123 .map_err(diesel::r2d2::Error::QueryError)
124 }
125}
126
127impl AnyConnection {
128 pub fn pre_init(database_url: &str) -> anyhow::Result<()> {
129 if database_url.starts_with("sqlite://") && !database_url.contains(":memory:") {
130 let path = database_url.split("sqlite://").collect::<Vec<&str>>()[1];
131 let path = PathBuf::from(path);
132 let parent = path.parent().unwrap();
133
134 if !std::path::Path::new(&path).exists() {
135 std::fs::create_dir_all(parent).with_context(|| {
136 format!("Failed to create database directory: {}", parent.display())
137 })?;
138 }
139
140 #[cfg(unix)]
141 if path.file_name().map(|x| x.to_str()) != Some(path.to_str()) {
142 use std::os::unix::fs::PermissionsExt;
143
144 std::fs::set_permissions(parent, std::fs::Permissions::from_mode(0o777))
145 .with_context(|| {
146 format!(
147 "Failed to set permissions on database directory: {}",
148 parent.display()
149 )
150 })?;
151 }
152 }
153
154 Ok(())
155 }
156
157 pub fn migrate(&mut self) -> anyhow::Result<()> {
158 let migrations_applied = match self {
159 #[cfg(feature = "postgres")]
160 Self::Postgresql(conn) => conn.run_pending_migrations(POSTGRES_MIGRATIONS),
161 #[cfg(feature = "sqlite")]
162 Self::Sqlite(conn) => conn.run_pending_migrations(SQLITE_MIGRATIONS),
163 }
164 .expect("Could not run migrations");
165
166 if !migrations_applied.is_empty() {
167 info!(
168 "Applied {} migration{}",
169 migrations_applied.len(),
170 sify(&migrations_applied)
171 );
172 }
173
174 Ok(())
175 }
176
177 pub fn save_changes(&mut self, ev: &mut ScanEvent) -> anyhow::Result<ScanEvent> {
178 let ev = match self {
179 #[cfg(feature = "postgres")]
180 Self::Postgresql(conn) => ev.save_changes::<ScanEvent>(conn),
181 #[cfg(feature = "sqlite")]
184 Self::Sqlite(conn) => ev.save_changes::<ScanEvent>(conn),
185 }?;
186
187 Ok(ev)
188 }
189
190 pub fn insert_and_return(&mut self, ev: &NewScanEvent) -> anyhow::Result<ScanEvent> {
191 match self {
192 #[cfg(feature = "postgres")]
193 Self::Postgresql(conn) => diesel::insert_into(crate::schema::scan_events::table)
194 .values(ev)
195 .returning(ScanEvent::as_returning())
196 .get_result::<ScanEvent>(conn)
197 .map_err(Into::into),
198 #[cfg(feature = "sqlite")]
199 Self::Sqlite(conn) => diesel::insert_into(crate::schema::scan_events::table)
200 .values(ev)
201 .returning(ScanEvent::as_returning())
202 .get_result::<ScanEvent>(conn)
203 .map_err(Into::into),
204 }
205 }
206}
207
208#[doc(hidden)]
209pub type DbPool = Pool<ConnectionManager<AnyConnection>>;
210
211#[doc(hidden)]
212pub fn get_conn(
213 pool: &Pool<ConnectionManager<AnyConnection>>,
214) -> anyhow::Result<PooledConnection<ConnectionManager<AnyConnection>>> {
215 pool.get().context("Failed to get connection from pool")
216}
217
218#[doc(hidden)]
219pub fn get_pool(database_url: &String) -> anyhow::Result<Pool<ConnectionManager<AnyConnection>>> {
220 let manager = ConnectionManager::<AnyConnection>::new(database_url);
221
222 let pool = Pool::builder()
223 .max_size(1)
224 .connection_customizer(Box::new(AcquireHook { setup: true }))
225 .build(manager)
226 .context("Failed to create pool");
227
228 drop(pool);
229
230 let manager = ConnectionManager::<AnyConnection>::new(database_url);
231
232 Pool::builder()
233 .connection_customizer(Box::new(AcquireHook::default()))
234 .build(manager)
235 .context("Failed to create pool")
236}