Skip to content

Commit

Permalink
feat(shortint): add modulus switch noise reduction in apply_programma…
Browse files Browse the repository at this point in the history
…ble_bootstrap
  • Loading branch information
mayeul-zama committed Jan 30, 2025
1 parent c465256 commit e646c3b
Show file tree
Hide file tree
Showing 4 changed files with 105 additions and 14 deletions.
20 changes: 18 additions & 2 deletions tfhe/src/shortint/list_compression/compression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@ use crate::shortint::ciphertext::CompressedCiphertextList;
use crate::shortint::engine::ShortintEngine;
use crate::shortint::parameters::{CarryModulus, MessageModulus, NoiseLevel};
use crate::shortint::server_key::{
apply_programmable_bootstrap, generate_lookup_table_with_encoding, unchecked_scalar_mul_assign,
apply_programmable_bootstrap_no_ms_noise_reduction, generate_lookup_table_with_encoding,
unchecked_scalar_mul_assign, ShortintBootstrappingKey,
};
use crate::shortint::{Ciphertext, CiphertextModulus, MaxNoiseLevel};
use rayon::iter::ParallelIterator;
Expand Down Expand Up @@ -201,14 +202,29 @@ impl DecompressionKey {
ciphertext_modulus,
);

match &self.blind_rotate_key {
ShortintBootstrappingKey::Classic {
bsk: _,
modulus_switch_noise_reduction_key,
} => {
assert!(
modulus_switch_noise_reduction_key.is_none(),
"Decompression key should not do modulus switch noise reduction"
);
}
ShortintBootstrappingKey::MultiBit { .. } => {
panic!("Decompression can't use a multi bit PBS")
}
}

ShortintEngine::with_thread_local_mut(|engine| {
let (_ciphertext_buffers, buffers) = engine.get_buffers_no_sk(
self.blind_rotate_key.input_lwe_dimension(),
self.blind_rotate_key.output_lwe_dimension(),
CiphertextModulus::new_native(),
);

apply_programmable_bootstrap(
apply_programmable_bootstrap_no_ms_noise_reduction(
&self.blind_rotate_key,
&intermediate_lwe,
&mut output_br,
Expand Down
6 changes: 4 additions & 2 deletions tfhe/src/shortint/oprf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@ use crate::core_crypto::prelude::{
use crate::shortint::ciphertext::Degree;
use crate::shortint::engine::ShortintEngine;
use crate::shortint::parameters::NoiseLevel;
use crate::shortint::server_key::{apply_programmable_bootstrap, LookupTableOwned};
use crate::shortint::server_key::{
apply_programmable_bootstrap_no_ms_noise_reduction, LookupTableOwned,
};
use crate::shortint::{PBSOrder, ServerKey};
use tfhe_csprng::seeders::Seed;

Expand Down Expand Up @@ -146,7 +148,7 @@ impl ServerKey {
ShortintEngine::with_thread_local_mut(|engine| {
let (_, buffers) = engine.get_buffers(self);

apply_programmable_bootstrap(
apply_programmable_bootstrap_no_ms_noise_reduction(
&self.bootstrapping_key,
&seeded,
&mut ct,
Expand Down
85 changes: 79 additions & 6 deletions tfhe/src/shortint/server_key/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ pub(crate) use scalar_mul::unchecked_scalar_mul_assign;
pub(crate) mod tests;

use crate::conformance::ParameterSetConformant;
use crate::core_crypto::algorithms::modulus_switch_noise_reduction::improve_lwe_ciphertext_modulus_switch_noise_for_binary_key;
use crate::core_crypto::algorithms::*;
use crate::core_crypto::commons::parameters::{
DecompositionBaseLog, DecompositionLevelCount, GlweDimension, GlweSize, LweBskGroupingFactor,
Expand All @@ -40,7 +41,7 @@ use crate::core_crypto::commons::traits::*;
use crate::core_crypto::entities::*;
use crate::core_crypto::fft_impl::fft64::crypto::bootstrap::BootstrapKeyConformanceParams;
use crate::core_crypto::fft_impl::fft64::math::fft::Fft;
use crate::core_crypto::prelude::ComputationBuffers;
use crate::core_crypto::prelude::{CiphertextModulusLog, ComputationBuffers};
use crate::shortint::ciphertext::{Ciphertext, Degree, MaxDegree, MaxNoiseLevel, NoiseLevel};
use crate::shortint::client_key::ClientKey;
use crate::shortint::engine::{
Expand Down Expand Up @@ -1512,22 +1513,77 @@ impl ServerKey {
}
}

pub(crate) fn apply_blind_rotate<Scalar, InputCont, OutputCont>(
pub(crate) fn apply_blind_rotate<InputCont, OutputCont>(
bootstrapping_key: &ShortintBootstrappingKey,
in_buffer: &LweCiphertext<InputCont>,
acc: &mut GlweCiphertext<OutputCont>,
buffers: &mut ComputationBuffers,
) where
Scalar: UnsignedTorus + CastInto<usize> + CastFrom<usize> + Sync,
InputCont: Container<Element = Scalar>,
OutputCont: ContainerMut<Element = Scalar>,
InputCont: Container<Element = u64>,
OutputCont: ContainerMut<Element = u64>,
{
let input_improved_before_ms;

let input_blind_rotate = match bootstrapping_key {
ShortintBootstrappingKey::Classic {
bsk,
modulus_switch_noise_reduction_key,
} if modulus_switch_noise_reduction_key.is_some() => {
input_improved_before_ms = apply_modulus_switch_noise_reduction(
modulus_switch_noise_reduction_key.as_ref().unwrap(),
bsk.polynomial_size().to_blind_rotation_input_modulus_log(),
in_buffer,
);

input_improved_before_ms.as_view()
}
_ => in_buffer.as_view(),
};

apply_blind_rotate_no_ms_noise_reduction(bootstrapping_key, &input_blind_rotate, acc, buffers);
}

pub(crate) fn apply_modulus_switch_noise_reduction<InputCont>(
modulus_switch_noise_reduction_key: &ModulusSwitchNoiseReductionKey,
log_modulus: CiphertextModulusLog,
in_buffer: &LweCiphertext<InputCont>,
) -> LweCiphertext<Vec<u64>>
where
InputCont: Container<Element = u64>,
{
let ciphertext_modulus = in_buffer.ciphertext_modulus();
let mut input: LweCiphertext<Vec<u64>> = LweCiphertext::from_container(
in_buffer.as_view().into_container().to_owned(),
ciphertext_modulus,
);

improve_lwe_ciphertext_modulus_switch_noise_for_binary_key(
&mut input,
&modulus_switch_noise_reduction_key.modulus_switch_zeros,
modulus_switch_noise_reduction_key.ms_r_sigma_factor,
modulus_switch_noise_reduction_key.ms_bound,
log_modulus,
);

input
}

pub(crate) fn apply_blind_rotate_no_ms_noise_reduction<InputCont, OutputCont>(
bootstrapping_key: &ShortintBootstrappingKey,
in_buffer: &LweCiphertext<InputCont>,
acc: &mut GlweCiphertext<OutputCont>,
buffers: &mut ComputationBuffers,
) where
InputCont: Container<Element = u64>,
OutputCont: ContainerMut<Element = u64>,
{
#[cfg(feature = "pbs-stats")]
let _ = PBS_COUNT.fetch_add(1, Ordering::Relaxed);

match bootstrapping_key {
ShortintBootstrappingKey::Classic {
bsk: fourier_bsk, ..
bsk: fourier_bsk,
modulus_switch_noise_reduction_key: _,
} => {
let fft = Fft::new(fourier_bsk.polynomial_size());
let fft = fft.as_view();
Expand Down Expand Up @@ -1578,6 +1634,23 @@ pub(crate) fn apply_programmable_bootstrap<InputCont, OutputCont>(
extract_lwe_sample_from_glwe_ciphertext(&glwe_out, out_buffer, MonomialDegree(0));
}

pub(crate) fn apply_programmable_bootstrap_no_ms_noise_reduction<InputCont, OutputCont>(
bootstrapping_key: &ShortintBootstrappingKey,
in_buffer: &LweCiphertext<InputCont>,
out_buffer: &mut LweCiphertext<OutputCont>,
acc: &GlweCiphertext<Vec<u64>>,
buffers: &mut ComputationBuffers,
) where
InputCont: Container<Element = u64>,
OutputCont: ContainerMut<Element = u64>,
{
let mut glwe_out: GlweCiphertext<_> = acc.clone();

apply_blind_rotate_no_ms_noise_reduction(bootstrapping_key, in_buffer, &mut glwe_out, buffers);

extract_lwe_sample_from_glwe_ciphertext(&glwe_out, out_buffer, MonomialDegree(0));
}

pub fn generate_lookup_table<F>(
glwe_size: GlweSize,
polynomial_size: PolynomialSize,
Expand Down
8 changes: 4 additions & 4 deletions tfhe/src/shortint/server_key/modulus_switched_compression.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use super::compressed_modulus_switched_multi_bit_lwe_ciphertext::CompressedModulusSwitchedMultiBitLweCiphertext;
use super::{
extract_lwe_sample_from_glwe_ciphertext, multi_bit_deterministic_blind_rotate_assign,
GlweCiphertext, ShortintBootstrappingKey,
apply_programmable_bootstrap_no_ms_noise_reduction, extract_lwe_sample_from_glwe_ciphertext,
multi_bit_deterministic_blind_rotate_assign, GlweCiphertext, ShortintBootstrappingKey,
};
use crate::core_crypto::commons::parameters::MonomialDegree;
use crate::core_crypto::prelude::compressed_modulus_switched_lwe_ciphertext::CompressedModulusSwitchedLweCiphertext;
Expand All @@ -10,7 +10,7 @@ use crate::shortint::ciphertext::{
CompressedModulusSwitchedCiphertext, InternalCompressedModulusSwitchedCiphertext, NoiseLevel,
};
use crate::shortint::engine::ShortintEngine;
use crate::shortint::server_key::{apply_programmable_bootstrap, LookupTableOwned};
use crate::shortint::server_key::LookupTableOwned;
use crate::shortint::{Ciphertext, PBSOrder, ServerKey};

impl ServerKey {
Expand Down Expand Up @@ -139,7 +139,7 @@ impl ServerKey {
panic!("Compression was done targeting a MultiBit bootstrap decompression, cannot decompress with a Classic bootstrapping key")
}
};
apply_programmable_bootstrap(
apply_programmable_bootstrap_no_ms_noise_reduction(
&self.bootstrapping_key,
&ct,
&mut ciphertext_buffers.buffer_lwe_after_pbs,
Expand Down

0 comments on commit e646c3b

Please sign in to comment.