Skip to content

Commit

Permalink
Merge pull request #1 from lattice-based-cryptography/create_hom_add_…
Browse files Browse the repository at this point in the history
…unit_test

create homomorphic addition test
  • Loading branch information
jacksonwalters authored Dec 28, 2024
2 parents f2f0792 + a4345f3 commit 495baa7
Show file tree
Hide file tree
Showing 5 changed files with 73 additions and 42 deletions.
6 changes: 3 additions & 3 deletions src/decrypt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,15 @@ use module_lwe::ring_mod::polysub;
pub fn decrypt(
sk: &Vec<Polynomial<i64>>, //secret key
q: i64, //ciphertext modulus
poly_mod: &Polynomial<i64>, //polynomial modulus
f: &Polynomial<i64>, //polynomial modulus
u: &Vec<Polynomial<i64>>, //ciphertext vector
v: &Polynomial<i64> //ciphertext polynomial
) -> Vec<i64> {
//Decrypt a ciphertext (u,v)
//Returns a plaintext vector

//Compute v-sk*u mod q
let scaled_pt = polysub(&v, &mul_vec_simple(&sk, &u, q, &poly_mod), q, &poly_mod);
let scaled_pt = polysub(&v, &mul_vec_simple(&sk, &u, q, &f), q, &f);
let half_q = (q as f64 / 2.0 + 0.5) as i64;
let mut decrypted_coeffs = vec![];
let mut s;
Expand Down Expand Up @@ -62,7 +62,7 @@ pub fn decrypt_string(sk_string: &String, ciphertext_string: &String, params: &P
let v = Polynomial::new(v_array.to_vec());

// Decrypt the ciphertext
let mut m_b = decrypt(&sk, q as i64, &f, &u, &v);
let mut m_b = decrypt(&sk, q, &f, &u, &v);
m_b.resize(n,0);

message_binary.extend(m_b);
Expand Down
29 changes: 15 additions & 14 deletions src/encrypt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,23 +6,29 @@ pub fn encrypt(
a: &Vec<Vec<Polynomial<i64>>>,
t: &Vec<Polynomial<i64>>,
m_b: Vec<i64>,
f: &Polynomial<i64>,
q: i64,
r: &Vec<Polynomial<i64>>,
e1: &Vec<Polynomial<i64>>,
e2: &Polynomial<i64>
params: &Parameters,
seed: Option<u64>
) -> (Vec<Polynomial<i64>>, Polynomial<i64>) {

//get parameters
let (n, q, k, f) = (params.n, params.q, params.k, &params.f);

//generate random ephermal keys
let r = gen_small_vector(n, k, seed);
let e1 = gen_small_vector(n, k, seed);
let e2 = gen_small_vector(n, 1, seed)[0].clone(); // Single polynomial

//compute nearest integer to q/2
let half_q = (q as f64 / 2.0 + 0.5) as i64;

// Convert binary message to polynomial
let m = Polynomial::new(vec![half_q])*Polynomial::new(m_b);

// Compute u = a^T * r + e_1 mod q
let u = add_vec(&mul_mat_vec_simple(&transpose(a), r, q, f), e1, q, f);
let u = add_vec(&mul_mat_vec_simple(&transpose(a), &r, q, f), &e1, q, f);

// Compute v = t * r + e_2 - m mod q
let v = polysub(&polyadd(&mul_vec_simple(t, r, q, f), e2, q, f), &m, q, f);
let v = polysub(&polyadd(&mul_vec_simple(t, &r, q, &f), &e2, q, f), &m, q, f);

(u, v)
}
Expand All @@ -31,12 +37,7 @@ pub fn encrypt(
pub fn encrypt_string(pk_string: &String, message_string: &String, params: &Parameters, seed: Option<u64>) -> String {

//get parameters
let (n, q, k, f) = (params.n, params.q, params.k, &params.f);

// Randomly generated values for r, e1, and e2
let r = gen_small_vector(n, k, seed);
let e1 = gen_small_vector(n, k, seed);
let e2 = gen_small_vector(n, 1, seed)[0].clone(); // Single polynomial
let (n, k) = (params.n, params.k);

// Parse public key

Expand Down Expand Up @@ -69,7 +70,7 @@ pub fn encrypt_string(pk_string: &String, message_string: &String, params: &Para
// Encrypt each block
let mut ciphertext_list = vec![];
for block in message_blocks {
let (u, v) = encrypt(&a, &t, block, &f, q as i64, &r, &e1, &e2);
let (u, v) = encrypt(&a, &t, block, params, seed);
let u_flattened: Vec<i64> = u.iter()
.flat_map(|poly| {
let mut coeffs = poly.coeffs().to_vec();
Expand Down
30 changes: 12 additions & 18 deletions src/keygen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,39 +3,33 @@ use module_lwe::{Parameters, add_vec, mul_mat_vec_simple, gen_small_vector, gen_
use std::collections::HashMap;

pub fn keygen(
size: usize, //polynomial modulus degree
modulus: i64, //ciphertext modulus
rank: usize, //module rank
poly_mod: &Polynomial<i64>, //polynomial modulus
params: &Parameters,
seed: Option<u64> //random seed
) -> (Vec<Vec<Polynomial<i64>>>, Vec<Polynomial<i64>>, Vec<Polynomial<i64>>) {
) -> ((Vec<Vec<Polynomial<i64>>>, Vec<Polynomial<i64>>), Vec<Polynomial<i64>>) {
let (n,q,k,f) = (params.n, params.q, params.k, &params.f);
//Generate a public and secret key
let a = gen_uniform_matrix(size, rank, modulus, seed);
let sk = gen_small_vector(size, rank, seed);
let e = gen_small_vector(size, rank, seed);
let t = add_vec(&mul_mat_vec_simple(&a, &sk, modulus, &poly_mod), &e, modulus, &poly_mod);
let a = gen_uniform_matrix(n, k, q, seed);
let sk = gen_small_vector(n, k, seed);
let e = gen_small_vector(n, k, seed);
let t = add_vec(&mul_mat_vec_simple(&a, &sk, q, &f), &e, q, &f);

//Return public key (a, t) and secret key (sk) as a 3-tuple
(a, t, sk)
//Return public key (a, t) and secret key (sk) as a 2-tuple
((a, t), sk)
}

//function to generate public/secret keys as key:value pairs
pub fn keygen_string(params: &Parameters, seed: Option<u64>) -> HashMap<String, String> {

//get parameters
let (n, q, k, f) = (params.n, params.q, params.k, &params.f);

//generate public, secret keys
let (a,t,sk) = keygen(n,q as i64,k,&f,seed);
let pk = (a,t);
let (pk,sk) = keygen(params,seed);

// Convert public key to a flattened list of coefficients
let mut pk_coeffs: Vec<i64> = pk.0
.iter()
.flat_map(|row| {
row.iter().flat_map(|poly| {
let mut coeffs = poly.coeffs().to_vec();
coeffs.resize(n, 0); // Resize to include leading zeros up to size `n`
coeffs.resize(params.n, 0); // Resize to include leading zeros up to size `n`
coeffs
})
})
Expand All @@ -44,7 +38,7 @@ pub fn keygen_string(params: &Parameters, seed: Option<u64>) -> HashMap<String,
pk.1.iter()
.flat_map(|poly| {
let mut coeffs = poly.coeffs().to_vec();
coeffs.resize(n, 0); // Resize to include leading zeros up to size `n`
coeffs.resize(params.n, 0); // Resize to include leading zeros up to size `n`
coeffs
})
);
Expand Down
6 changes: 3 additions & 3 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,15 @@ use ring_mod::{polyadd, polymul, gen_uniform_poly};
#[derive(Debug)]
pub struct Parameters {
pub n: usize, // Polynomial modulus degree
pub q: usize, // Ciphertext modulus
pub q: i64, // Ciphertext modulus
pub k: usize, // Plaintext modulus
pub f: Polynomial<i64>, // Polynomial modulus (x^n + 1 representation)
}

impl Default for Parameters {
fn default() -> Self {
let n = 16;
let q = 67;
let n = 8;
let q = 59049;
let k = 2;
let mut poly_vec = vec![0i64;n+1];
poly_vec[0] = 1;
Expand Down
44 changes: 40 additions & 4 deletions src/test.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
#[cfg(test)] // This makes the following module compile only during tests
mod tests {
use crate::keygen::keygen_string;
use crate::encrypt::encrypt_string;
use crate::decrypt::decrypt_string;
use module_lwe::Parameters;
use crate::keygen::{keygen,keygen_string};
use crate::encrypt::{encrypt,encrypt_string};
use crate::decrypt::{decrypt,decrypt_string};
use module_lwe::{Parameters,add_vec};
use module_lwe::ring_mod::{polyadd};

// Test for basic keygen/encrypt/decrypt of a message
#[test]
Expand All @@ -18,4 +19,39 @@ mod tests {
let decrypted_message = decrypt_string(&sk_string, &ciphertext_string, &params);
assert_eq!(message, decrypted_message, "test failed: {} != {}", message, decrypted_message);
}

// Test homomorphic addition property:
// for plaintext polynomials m0, m1
// assert: dec(enc(m0) + enc(m1)) = m0 + m1
// note that the plaintext modulus is t=2 implicitly
// since we are using half_q = q/2 = q/t = delta
// the homomorphism is enc: R_q -> R_t, and dec: R_t -> R_q
#[test]
pub fn test_hom_add() {

let seed = None; //set the random seed
let params = Parameters::default();
let (n, q, f) = (params.n, params.q, &params.f);

let mut m0 = vec![1, 0, 1];
m0.resize(n, 0);
let mut m1 = vec![0, 0, 1];
m1.resize(n, 0);
let mut plaintext_sum = vec![1, 0, 0];
plaintext_sum.resize(n, 0);
let (pk, sk) = keygen(&params,seed);

// Encrypt plaintext messages
let u = encrypt(&pk.0, &pk.1, m0, &params, seed);
let v = encrypt(&pk.0, &pk.1, m1, &params, seed);

// Compute sum of encrypted data
let ciphertext_sum = (add_vec(&u.0,&v.0,q,f), polyadd(&u.1,&v.1,q,f));

// Decrypt ciphertext sum u+v
let mut decrypted_sum = decrypt(&sk, q, f, &ciphertext_sum.0, &ciphertext_sum.1);
decrypted_sum.resize(n, 0);

assert_eq!(decrypted_sum, plaintext_sum, "test failed: {:?} != {:?}", decrypted_sum, plaintext_sum);
}
}

0 comments on commit 495baa7

Please sign in to comment.