diff --git a/Cargo.toml b/Cargo.toml index d66e502..7464e72 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,10 +10,14 @@ repository = "https://github.com/mrivnak/pond" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] +bitcode = { version = "0.6.0", default-features = false, features = ["serde"] } chrono = { version = "0.4.38", features = ["serde"] } -rusqlite = { version = "0.31.0", features = ["bundled"] } +rusqlite = { version = "0.31.0", features = ["blob", "bundled"] } +serde = "1.0.202" [dev-dependencies] +bitcode = { version = "0.6.0", features = ["serde"] } rand = "0.8.5" -uuid = { version = "1.8.0", features = ["v4"] } +serde = { version = "1.0.202", features = ["derive"] } +uuid = { version = "1.8.0", features = ["v4", "serde"] } diff --git a/src/lib.rs b/src/lib.rs index 91dfd31..e460946 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,29 +1,30 @@ use std::hash::{DefaultHasher, Hash, Hasher}; use std::path::PathBuf; -use std::time::Instant; use chrono::{DateTime, Duration, Utc}; use rusqlite::Connection; +use serde::de::DeserializeOwned; +use serde::Serialize; -pub use rusqlite::types::{FromSql, ToSql}; pub use rusqlite::Error; -pub struct Cache { +pub struct Cache { path: PathBuf, ttl: Duration, + data: std::marker::PhantomData, } #[derive(Debug)] -pub struct CacheEntry +struct CacheEntry where - T: ToSql + FromSql, + T: Serialize + DeserializeOwned + Clone, { key: u32, value: T, expiration: DateTime, } -impl Cache { +impl Cache { pub fn new(path: PathBuf) -> Result { Self::with_time_to_live(path, Duration::minutes(10)) } @@ -35,20 +36,23 @@ impl Cache { "CREATE TABLE IF NOT EXISTS items ( id TEXT PRIMARY KEY, expires TEXT NOT NULL, - data TEXT NOT NULL + data BLOB NOT NULL )", - (), // empty list of parameters. + (), )?; db.close().expect("Failed to close database connection"); - Ok(Self { path, ttl }) + Ok(Self { + path, + ttl, + data: std::marker::PhantomData, + }) } - pub fn get(&self, key: K) -> Result, Error> + pub fn get(&self, key: K) -> Result, Error> where K: Hash, - T: ToSql + FromSql, { let db = Connection::open(self.path.as_path())?; @@ -77,12 +81,14 @@ impl Cache { .with_timezone(&Utc) }) .unwrap(); - let data: T = row.get(2).unwrap(); + let data: Vec = row.get(2).unwrap(); drop(rows); drop(stmt); db.close().expect("Failed to close database connection"); + let data: T = bitcode::deserialize(&data).unwrap(); + if expires < Utc::now() { Ok(None) } else { @@ -90,24 +96,16 @@ impl Cache { } } - pub fn store(&self, key: K, value: T) -> Result<(), Error> - where - K: Hash, - T: ToSql + FromSql, - { + pub fn store(&self, key: K, value: T) -> Result<(), Error> { self.store_with_expiration(key, value, Utc::now() + self.ttl) } - pub fn store_with_expiration( + pub fn store_with_expiration( &self, key: K, value: T, expiration: DateTime, - ) -> Result<(), Error> - where - K: Hash, - T: ToSql + FromSql, - { + ) -> Result<(), Error> { let mut hasher = DefaultHasher::new(); let hash = { key.hash(&mut hasher); @@ -127,7 +125,7 @@ impl Cache { ( &value.key.to_string(), &value.expiration.to_rfc3339(), - &value.value, + &bitcode::serialize(&value.value).unwrap(), ), )?; @@ -152,14 +150,22 @@ impl Cache { #[cfg(test)] mod tests { + use serde::Deserialize; + use serde::Serialize; use uuid::Uuid; use super::*; + #[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] + struct User { + id: Uuid, + name: String, + } + fn store_manual( path: PathBuf, key: String, - value: String, + value: Vec, expires: DateTime, ) -> Result<(), Error> { let mut hasher = DefaultHasher::new(); @@ -180,7 +186,7 @@ mod tests { Ok(()) } - fn get_manual( + fn get_manual( path: PathBuf, key: String, ) -> Result>, Error> { @@ -212,12 +218,14 @@ mod tests { .with_timezone(&Utc) }) .unwrap(); - let data: T = row.get(2).unwrap(); + let data: Vec = row.get(2).unwrap(); drop(rows); drop(stmt); db.close().expect("Failed to close database connection"); + let data: T = bitcode::deserialize(&data).unwrap(); + Ok(Some(CacheEntry { key: hash, value: data, @@ -232,7 +240,7 @@ mod tests { Uuid::new_v4(), rand::random::() )); - let cache = Cache::new(filename.clone()).unwrap(); + let cache: Cache = Cache::new(filename.clone()).unwrap(); assert_eq!(cache.path, filename); assert_eq!(cache.ttl, Duration::minutes(10)); } @@ -244,8 +252,8 @@ mod tests { Uuid::new_v4(), rand::random::() )); - let _ = Cache::new(filename.clone()).unwrap(); - let _ = Cache::new(filename).unwrap(); + let _: Cache = Cache::new(filename.clone()).unwrap(); + let _: Cache = Cache::new(filename).unwrap(); } #[test] @@ -255,7 +263,8 @@ mod tests { Uuid::new_v4(), rand::random::() )); - let cache = Cache::with_time_to_live(filename.clone(), Duration::minutes(5)).unwrap(); + let cache: Cache = + Cache::with_time_to_live(filename.clone(), Duration::minutes(5)).unwrap(); assert_eq!(cache.path, filename); assert_eq!(cache.ttl, Duration::minutes(5)); } @@ -279,6 +288,28 @@ mod tests { assert_eq!(result, Some(value)); } + #[test] + fn test_store_get_struct() { + let filename = std::env::temp_dir().join(format!( + "pond-test-{}-{}.sqlite", + Uuid::new_v4(), + rand::random::() + )); + + let cache = Cache::new(filename).unwrap(); + + let key = Uuid::new_v4(); + let value = User { + id: Uuid::new_v4(), + name: String::from("Alice"), + }; + + cache.store(key, value.clone()).unwrap(); + let result: Option<_> = cache.get(key).unwrap(); + + assert_eq!(result, Some(value)); + } + #[test] fn test_store_existing() { let filename = std::env::temp_dir().join(format!( @@ -317,7 +348,7 @@ mod tests { store_manual( filename, key.to_string(), - value, + bitcode::serialize(&value).unwrap(), Utc::now() - Duration::minutes(5), ) .unwrap(); @@ -345,7 +376,8 @@ mod tests { #[test] fn test_invalid_path() { - let cache = Cache::new(PathBuf::from("invalid/path/db.sqlite")); + let cache: Result, Error> = + Cache::new(PathBuf::from("invalid/path/db.sqlite")); assert!(cache.is_err()); } @@ -358,7 +390,8 @@ mod tests { rand::random::() )); - let cache = Cache::with_time_to_live(filename.clone(), Duration::minutes(5)).unwrap(); + let cache: Cache = + Cache::with_time_to_live(filename.clone(), Duration::minutes(5)).unwrap(); let key = Uuid::new_v4().to_string(); let value = String::from("Hello, world!"); @@ -366,7 +399,7 @@ mod tests { store_manual( filename.clone(), key.clone(), - value.clone(), + bitcode::serialize(&value).unwrap(), Utc::now() - Duration::minutes(5), ) .unwrap(); @@ -391,7 +424,8 @@ mod tests { rand::random::() )); - let cache = Cache::with_time_to_live(filename.clone(), Duration::minutes(5)).unwrap(); + let cache: Cache = + Cache::with_time_to_live(filename.clone(), Duration::minutes(5)).unwrap(); let key = Uuid::new_v4().to_string(); let value = String::from("Hello, world!"); @@ -399,7 +433,7 @@ mod tests { store_manual( filename.clone(), key.clone(), - value.clone(), + bitcode::serialize(&value).unwrap(), Utc::now() + Duration::minutes(15), ) .unwrap();