//! `autopulse_database/conn.rs`: database connection, pooling and migration helpers.
use crate::models::{NewScanEvent, ScanEvent};
use anyhow::Context;
use autopulse_utils::sify;
use diesel::connection::SimpleConnection;
use diesel::r2d2::{ConnectionManager, Pool, PooledConnection};
use diesel::{Connection, RunQueryDsl};
use diesel::{SaveChangesDsl, SelectableHelper};
use diesel_migrations::{embed_migrations, EmbeddedMigrations, MigrationHarness};
use serde::Deserialize;
use std::path::PathBuf;
use tracing::{info, warn};

// `embed_migrations!` compiles the SQL migrations found in these directories into
// the binary. Only the set belonging to an enabled backend feature is embedded.

#[doc(hidden)]
#[cfg(feature = "sqlite")]
const SQLITE_MIGRATIONS: EmbeddedMigrations = embed_migrations!("migrations/sqlite");

#[doc(hidden)]
#[cfg(feature = "postgres")]
const POSTGRES_MIGRATIONS: EmbeddedMigrations = embed_migrations!("migrations/postgres");

21#[derive(Deserialize, Debug)]
22#[serde(rename_all = "lowercase")]
23#[derive(Default)]
24pub enum DatabaseType {
25 #[cfg(feature = "sqlite")]
26 #[cfg_attr(feature = "sqlite", default)]
27 Sqlite,
28 #[cfg(feature = "postgres")]
29 #[cfg_attr(not(feature = "sqlite"), default)]
30 Postgres,
31}
32
33impl DatabaseType {
34 pub fn default_url(&self) -> String {
35 match self {
36 #[cfg(feature = "sqlite")]
37 Self::Sqlite => "sqlite://data/autopulse.db".to_string(),
38 #[cfg(feature = "postgres")]
39 Self::Postgres => "postgres://autopulse:autopulse@localhost:5432/autopulse".to_string(),
40 }
41 }
42}
43
/// A diesel connection backed by whichever database backend is enabled.
///
/// `#[derive(diesel::MultiConnection)]` makes this enum usable wherever a diesel
/// connection is expected; the pools in this module are
/// `ConnectionManager<AnyConnection>`.
///
/// NOTE(review): the derive presumably tries the variants in declaration order
/// when establishing a connection from a URL — confirm before reordering them.
#[derive(diesel::MultiConnection)]
pub enum AnyConnection {
    // PostgreSQL backend; only compiled with the `postgres` feature.
    #[cfg(feature = "postgres")]
    Postgresql(diesel::PgConnection),
    // SQLite backend; only compiled with the `sqlite` feature.
    #[cfg(feature = "sqlite")]
    Sqlite(diesel::SqliteConnection),
}

/// Connection customizer installed on the r2d2 pools built by `get_pool`.
///
/// `setup` controls whether the one-time maintenance statements run in
/// addition to the per-connection settings. `Default` leaves it off.
#[doc(hidden)]
#[derive(Default, Debug)]
pub struct AcquireHook {
    /// When `true`, also run the one-time setup statements on each new connection.
    pub setup: bool,
}

88impl diesel::r2d2::CustomizeConnection<AnyConnection, diesel::r2d2::Error> for AcquireHook {
89 fn on_acquire(&self, conn: &mut AnyConnection) -> Result<(), diesel::r2d2::Error> {
90 (|| {
91 match conn {
92 #[cfg(feature = "sqlite")]
93 AnyConnection::Sqlite(ref mut conn) => {
94 conn.batch_execute("PRAGMA busy_timeout = 5000")?;
95 conn.batch_execute("PRAGMA synchronous = NORMAL;")?;
96 conn.batch_execute("PRAGMA wal_autocheckpoint = 1000;")?;
97 conn.batch_execute("PRAGMA foreign_keys = ON;")?;
98
99 if self.setup {
100 conn.batch_execute("PRAGMA journal_mode = WAL;")?;
101 conn.batch_execute("VACUUM")?;
102 }
103 }
104 #[cfg(feature = "postgres")]
105 AnyConnection::Postgresql(ref mut conn) => {
106 if self.setup {
107 conn.batch_execute("VACUUM ANALYZE")?;
108 }
109 }
110 }
111 Ok(())
112 })()
113 .map_err(diesel::r2d2::Error::QueryError)
114 }
115}
116
117impl AnyConnection {
118 pub fn pre_init(database_url: &str) -> anyhow::Result<()> {
119 if database_url.starts_with("sqlite://") && !database_url.contains(":memory:") {
120 let path = database_url.split("sqlite://").collect::<Vec<&str>>()[1];
121 let path = PathBuf::from(path);
122 let parent = path.parent().unwrap();
123
124 if !std::path::Path::new(&path).exists() {
125 std::fs::create_dir_all(parent).with_context(|| {
126 format!("falsed to create database directory: {}", parent.display())
127 })?;
128 }
129
130 #[cfg(unix)]
131 if path.file_name().map(|x| x.to_str()) != Some(path.to_str()) {
132 use std::os::unix::fs::PermissionsExt;
133
134 std::fs::set_permissions(parent, std::fs::Permissions::from_mode(0o777))
135 .with_context(|| {
136 format!(
137 "falsed to set permissions on database directory: {}",
138 parent.display()
139 )
140 })?;
141 }
142 }
143
144 Ok(())
145 }
146
147 pub fn migrate(&mut self) -> anyhow::Result<()> {
148 let migrations_applied = match self {
149 #[cfg(feature = "postgres")]
150 Self::Postgresql(conn) => conn.run_pending_migrations(POSTGRES_MIGRATIONS),
151 #[cfg(feature = "sqlite")]
152 Self::Sqlite(conn) => conn.run_pending_migrations(SQLITE_MIGRATIONS),
153 }
154 .expect("Could not run migrations");
155
156 if !migrations_applied.is_empty() {
157 info!(
158 "Applied {} migration{}",
159 migrations_applied.len(),
160 sify(&migrations_applied)
161 );
162 }
163
164 Ok(())
165 }
166
167 pub fn save_changes(&mut self, ev: &mut ScanEvent) -> anyhow::Result<ScanEvent> {
168 let ev = match self {
169 #[cfg(feature = "postgres")]
170 Self::Postgresql(conn) => ev.save_changes::<ScanEvent>(conn),
171 #[cfg(feature = "sqlite")]
174 Self::Sqlite(conn) => ev.save_changes::<ScanEvent>(conn),
175 }?;
176
177 Ok(ev)
178 }
179
180 pub fn insert_and_return(&mut self, ev: &NewScanEvent) -> anyhow::Result<ScanEvent> {
181 match self {
182 #[cfg(feature = "postgres")]
183 Self::Postgresql(conn) => diesel::insert_into(crate::schema::scan_events::table)
184 .values(ev)
185 .returning(ScanEvent::as_returning())
186 .get_result::<ScanEvent>(conn)
187 .map_err(Into::into),
188 #[cfg(feature = "sqlite")]
189 Self::Sqlite(conn) => diesel::insert_into(crate::schema::scan_events::table)
190 .values(ev)
191 .returning(ScanEvent::as_returning())
192 .get_result::<ScanEvent>(conn)
193 .map_err(Into::into),
194 }
195 }
196}
197
198#[doc(hidden)]
199pub type DbPool = Pool<ConnectionManager<AnyConnection>>;
200
201#[doc(hidden)]
202pub fn get_conn(
203 pool: &Pool<ConnectionManager<AnyConnection>>,
204) -> anyhow::Result<PooledConnection<ConnectionManager<AnyConnection>>> {
205 pool.get().context("failed to get connection from pool")
206}
207
208#[doc(hidden)]
209pub fn get_pool(database_url: &String) -> anyhow::Result<Pool<ConnectionManager<AnyConnection>>> {
210 let manager = ConnectionManager::<AnyConnection>::new(database_url);
211
212 let pool = Pool::builder()
213 .max_size(1)
214 .connection_customizer(Box::new(AcquireHook { setup: true }))
215 .build(manager)
216 .context("failed to create pool");
217
218 drop(pool);
219
220 let manager = ConnectionManager::<AnyConnection>::new(database_url);
221
222 Pool::builder()
223 .connection_customizer(Box::new(AcquireHook::default()))
224 .build(manager)
225 .context("failed to create pool")
226}