Skip to content

Commit

Permalink
fix(gpu): fix default multi-bit PBS for multi-device execution of integer ops
Browse files Browse the repository at this point in the history
  • Loading branch information
pdroalves authored and agnesLeroy committed Jan 31, 2025
1 parent cc6edd0 commit c470b71
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,8 @@ template <typename Torus> struct pbs_buffer<Torus, PBS_TYPE::MULTI_BIT> {
uint32_t polynomial_size, uint32_t level_count,
uint32_t input_lwe_ciphertext_count, uint32_t lwe_chunk_size,
PBS_VARIANT pbs_variant, bool allocate_gpu_memory) {
cudaSetDevice(gpu_index);

this->pbs_variant = pbs_variant;
this->lwe_chunk_size = lwe_chunk_size;
auto max_shared_memory = cuda_get_max_shared_memory(gpu_index);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,8 @@ __host__ void scratch_cg_multi_bit_programmable_bootstrap(
uint32_t polynomial_size, uint32_t level_count,
uint32_t input_lwe_ciphertext_count, bool allocate_gpu_memory) {

cudaSetDevice(gpu_index);

uint64_t full_sm_keybundle =
get_buffer_size_full_sm_multibit_programmable_bootstrap_keybundle<Torus>(
polynomial_size);
Expand Down Expand Up @@ -296,6 +298,7 @@ __host__ void execute_cg_external_product_loop(
uint32_t polynomial_size, uint32_t grouping_factor, uint32_t base_log,
uint32_t level_count, uint32_t lwe_offset, uint32_t num_many_lut,
uint32_t lut_stride) {
cudaSetDevice(gpu_index);

uint64_t full_sm =
get_buffer_size_full_sm_cg_multibit_programmable_bootstrap<Torus>(
Expand All @@ -310,7 +313,6 @@ __host__ void execute_cg_external_product_loop(

auto lwe_chunk_size = buffer->lwe_chunk_size;
int max_shared_memory = cuda_get_max_shared_memory(gpu_index);
cudaSetDevice(gpu_index);

uint32_t keybundle_size_per_input =
lwe_chunk_size * level_count * (glwe_dimension + 1) *
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -388,6 +388,8 @@ __host__ void scratch_multi_bit_programmable_bootstrap(
uint32_t polynomial_size, uint32_t level_count,
uint32_t input_lwe_ciphertext_count, bool allocate_gpu_memory) {

cudaSetDevice(gpu_index);

int max_shared_memory = cuda_get_max_shared_memory(gpu_index);
uint64_t full_sm_keybundle =
get_buffer_size_full_sm_multibit_programmable_bootstrap_keybundle<Torus>(
Expand Down Expand Up @@ -494,6 +496,7 @@ __host__ void execute_compute_keybundle(
pbs_buffer<Torus, MULTI_BIT> *buffer, uint32_t num_samples,
uint32_t lwe_dimension, uint32_t glwe_dimension, uint32_t polynomial_size,
uint32_t grouping_factor, uint32_t level_count, uint32_t lwe_offset) {
cudaSetDevice(gpu_index);

auto lwe_chunk_size = buffer->lwe_chunk_size;
uint32_t chunk_size =
Expand All @@ -507,7 +510,6 @@ __host__ void execute_compute_keybundle(
get_buffer_size_full_sm_multibit_programmable_bootstrap_keybundle<Torus>(
polynomial_size);
int max_shared_memory = cuda_get_max_shared_memory(gpu_index);
cudaSetDevice(gpu_index);

auto d_mem = buffer->d_mem_keybundle;
auto keybundle_fft = buffer->keybundle_fft;
Expand Down Expand Up @@ -543,6 +545,7 @@ execute_step_one(cudaStream_t stream, uint32_t gpu_index,
uint32_t lwe_dimension, uint32_t glwe_dimension,
uint32_t polynomial_size, uint32_t base_log,
uint32_t level_count, uint32_t j, uint32_t lwe_offset) {
cudaSetDevice(gpu_index);

uint64_t full_sm_accumulate_step_one =
get_buffer_size_full_sm_multibit_programmable_bootstrap_step_one<Torus>(
Expand All @@ -551,7 +554,6 @@ execute_step_one(cudaStream_t stream, uint32_t gpu_index,
get_buffer_size_partial_sm_multibit_programmable_bootstrap_step_one<
Torus>(polynomial_size);
int max_shared_memory = cuda_get_max_shared_memory(gpu_index);
cudaSetDevice(gpu_index);

//
auto d_mem = buffer->d_mem_acc_step_one;
Expand Down Expand Up @@ -599,13 +601,13 @@ execute_step_two(cudaStream_t stream, uint32_t gpu_index, Torus *lwe_array_out,
uint32_t polynomial_size, int32_t grouping_factor,
uint32_t level_count, uint32_t j, uint32_t lwe_offset,
uint32_t num_many_lut, uint32_t lut_stride) {
cudaSetDevice(gpu_index);

auto lwe_chunk_size = buffer->lwe_chunk_size;
uint64_t full_sm_accumulate_step_two =
get_buffer_size_full_sm_multibit_programmable_bootstrap_step_two<Torus>(
polynomial_size);
int max_shared_memory = cuda_get_max_shared_memory(gpu_index);
cudaSetDevice(gpu_index);

auto d_mem = buffer->d_mem_acc_step_two;
auto keybundle_fft = buffer->keybundle_fft;
Expand Down

0 comments on commit c470b71

Please sign in to comment.