autopulse_database/
conn.rs1use crate::models::{NewScanEvent, ScanEvent};
2use anyhow::Context;
3use autopulse_utils::sify;
4use diesel::connection::SimpleConnection;
5use diesel::r2d2::{ConnectionManager, Pool, PooledConnection};
6use diesel::{Connection, QueryResult, RunQueryDsl};
7use diesel::{SaveChangesDsl, SelectableHelper};
8use diesel_migrations::{embed_migrations, EmbeddedMigrations, MigrationHarness};
9use std::os::unix::fs::PermissionsExt;
10use std::path::PathBuf;
11use tracing::info;
12
/// Schema migrations for the PostgreSQL backend, embedded into the binary at
/// compile time from the `migrations/postgres` directory.
#[cfg(feature = "postgres")]
#[doc(hidden)]
const POSTGRES_MIGRATIONS: EmbeddedMigrations = embed_migrations!("migrations/postgres");
16
/// Schema migrations for the SQLite backend, embedded into the binary at
/// compile time from the `migrations/sqlite` directory.
#[cfg(feature = "sqlite")]
#[doc(hidden)]
const SQLITE_MIGRATIONS: EmbeddedMigrations = embed_migrations!("migrations/sqlite");
20
21#[derive(diesel::MultiConnection)]
23pub enum AnyConnection {
24 #[cfg(feature = "postgres")]
34 Postgresql(diesel::PgConnection),
35 #[cfg(feature = "sqlite")]
56 Sqlite(diesel::SqliteConnection),
57}
58
/// r2d2 connection customizer that configures every connection the pool
/// opens (see its `CustomizeConnection::on_acquire` implementation).
#[derive(Default, Debug)]
#[doc(hidden)]
pub struct AcquireHook {
    /// When `true`, additionally run the one-off maintenance statements
    /// (SQLite: switch to WAL journaling and `VACUUM`; PostgreSQL:
    /// `VACUUM ANALYZE`). `Default` leaves this `false`.
    pub setup: bool,
}
64
65impl diesel::r2d2::CustomizeConnection<AnyConnection, diesel::r2d2::Error> for AcquireHook {
66 fn on_acquire(&self, conn: &mut AnyConnection) -> Result<(), diesel::r2d2::Error> {
67 (|| {
68 match conn {
69 #[cfg(feature = "sqlite")]
70 AnyConnection::Sqlite(ref mut conn) => {
71 conn.batch_execute("PRAGMA busy_timeout = 5000")?;
72 conn.batch_execute("PRAGMA synchronous = NORMAL;")?;
73 conn.batch_execute("PRAGMA wal_autocheckpoint = 1000;")?;
74 conn.batch_execute("PRAGMA foreign_keys = ON;")?;
75
76 if self.setup {
77 conn.batch_execute("PRAGMA journal_mode = WAL;")?;
78 conn.batch_execute("VACUUM")?;
79 }
80 }
81 #[cfg(feature = "postgres")]
82 AnyConnection::Postgresql(ref mut conn) => {
83 if self.setup {
84 conn.batch_execute("VACUUM ANALYZE")?;
85 }
86 }
87 }
88 Ok(())
89 })()
90 .map_err(diesel::r2d2::Error::QueryError)
91 }
92}
93
94impl AnyConnection {
95 pub fn pre_init(database_url: &str) -> anyhow::Result<()> {
96 if database_url.starts_with("sqlite://") && !database_url.contains(":memory:") {
97 let path = database_url.split("sqlite://").collect::<Vec<&str>>()[1];
98 let path = PathBuf::from(path);
99 let parent = path.parent().unwrap();
100
101 if !std::path::Path::new(&path).exists() {
102 std::fs::create_dir_all(parent).with_context(|| {
103 format!("Failed to create database directory: {}", parent.display())
104 })?;
105 }
106
107 if path.file_name().map(|x| x.to_str()) != Some(path.to_str()) {
108 std::fs::set_permissions(parent, std::fs::Permissions::from_mode(0o777))
109 .with_context(|| {
110 format!(
111 "Failed to set permissions on database directory: {}",
112 parent.display()
113 )
114 })?;
115 }
116 }
117
118 Ok(())
119 }
120
121 pub fn migrate(&mut self) -> anyhow::Result<()> {
122 let migrations_applied = match self {
123 #[cfg(feature = "postgres")]
124 Self::Postgresql(conn) => conn.run_pending_migrations(POSTGRES_MIGRATIONS),
125 #[cfg(feature = "sqlite")]
126 Self::Sqlite(conn) => conn.run_pending_migrations(SQLITE_MIGRATIONS),
127 }
128 .expect("Could not run migrations");
129
130 if !migrations_applied.is_empty() {
131 info!(
132 "Applied {} migration{}",
133 migrations_applied.len(),
134 sify(&migrations_applied)
135 );
136 }
137
138 Ok(())
139 }
140
141 pub fn save_changes(&mut self, ev: &mut ScanEvent) -> anyhow::Result<ScanEvent> {
142 let ev = match self {
143 #[cfg(feature = "postgres")]
144 Self::Postgresql(conn) => ev.save_changes::<ScanEvent>(conn),
145 #[cfg(feature = "sqlite")]
148 Self::Sqlite(conn) => ev.save_changes::<ScanEvent>(conn),
149 }?;
150
151 Ok(ev)
152 }
153
154 pub fn insert_and_return(&mut self, ev: &NewScanEvent) -> anyhow::Result<ScanEvent> {
155 match self {
156 #[cfg(feature = "postgres")]
157 Self::Postgresql(conn) => diesel::insert_into(crate::schema::scan_events::table)
158 .values(ev)
159 .returning(ScanEvent::as_returning())
160 .get_result::<ScanEvent>(conn)
161 .map_err(Into::into),
162 #[cfg(feature = "sqlite")]
163 Self::Sqlite(conn) => diesel::insert_into(crate::schema::scan_events::table)
164 .values(ev)
165 .returning(ScanEvent::as_returning())
166 .get_result::<ScanEvent>(conn)
167 .map_err(Into::into),
168 }
169 }
170}
171
172#[doc(hidden)]
173pub type DbPool = Pool<ConnectionManager<AnyConnection>>;
174
175#[doc(hidden)]
176pub fn get_conn(
177 pool: &Pool<ConnectionManager<AnyConnection>>,
178) -> anyhow::Result<PooledConnection<ConnectionManager<AnyConnection>>> {
179 pool.get().context("Failed to get connection from pool")
180}
181
182#[doc(hidden)]
183pub fn get_pool(database_url: &String) -> anyhow::Result<Pool<ConnectionManager<AnyConnection>>> {
184 let manager = ConnectionManager::<AnyConnection>::new(database_url);
185
186 let pool = Pool::builder()
187 .max_size(1)
188 .connection_customizer(Box::new(AcquireHook { setup: true }))
189 .build(manager)
190 .context("Failed to create pool");
191
192 drop(pool);
193
194 let manager = ConnectionManager::<AnyConnection>::new(database_url);
195
196 Pool::builder()
197 .connection_customizer(Box::new(AcquireHook::default()))
198 .build(manager)
199 .context("Failed to create pool")
200}