Skip to content

Commit

Permalink
fix modswitch error criteria and first FNT code
Browse files Browse the repository at this point in the history
  • Loading branch information
nindanaoto committed Jan 5, 2025
1 parent cef10bf commit a66369c
Show file tree
Hide file tree
Showing 4 changed files with 186 additions and 2 deletions.
132 changes: 132 additions & 0 deletions include/fnt.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
#pragma once
#include <cstdint>
#include <array>
#include <span>

namespace FNTpp{
constexpr unsigned int Kbit = 5;
constexpr unsigned int K = 1 << Kbit;
constexpr int64_t P = (1ULL << K) + 1;
constexpr int64_t wordmask = (1ULL << K) - 1;

template <uint8_t bit>
uint32_t BitReverse(uint32_t in)
{
if constexpr (bit > 1) {
const uint32_t center = in & ((bit & 1) << (bit / 2));
return (BitReverse<bit / 2>(in & ((1U << (bit / 2)) - 1))
<< (bit + 1) / 2) |
center | BitReverse<bit / 2>(in >> ((bit + 1) / 2));
}
else {
return in;
}
}

static inline int64_t ModLshift(int64_t a, uint8_t b)
{
// If b >= 32, multiply by 2^32 ≡ -1 (mod P).
// => a = P - a (unless a == 0), then reduce b by 32.
if (b >= 32) {
if (a != 0) {
a = P - a;
}
b -= 32; // now b < 32
}

// Shift by b < 32 in 64-bit arithmetic (safe from overflow).
int64_t r = a << b;

// Now reduce a modulo P:
// hi = upper 32 bits
// lo = lower 32 bits
// Since (hi << 32) ≡ -hi (mod P),
// we can do (lo + hi) mod P and then subtract P if needed.
const int64_t hi = r >> K;
const int64_t lo = r & wordmask;
r = -hi + lo;

// Subtract P once or twice if needed to ensure a < P
if (r < 0) r += P;
if (r >= P) r -= P;
return r;
}

template <uint8_t Nbit>
inline void MulInvN(std::array<int64_t, 1u<<Nbit>& a){
for(int i = 0; i < (1u<<Nbit); i++) a[i] = ModLshift(a[i], 2*K-Nbit);
}

template <unsigned int Nbit>
void FNT(const std::span<int64_t, 1u << Nbit> res)
{
if constexpr (Nbit == 1){
const int64_t temp = res[0];
res[0] += res[1];
if(res[0] >= P) res[0] -= P;
res[1] = temp - res[1];
if(res[1] < 0) res[1] += P;
}else{
constexpr unsigned int N = 1u << Nbit;
constexpr unsigned int stride = 1u << (Kbit+1 - Nbit);
for(unsigned int i = 0; i < N/2; i++){
const int64_t temp = res[i]+res[i+N/2];
res[i+N/2] = res[i]-res[i+N/2];
if(res[i+N/2] < 0) res[i+N/2] += P;
if(i!=0) res[i+N/2] = ModLshift(res[i+N/2],i*stride);
res[i] = temp >= P ? temp - P : temp;
}
FNT<Nbit-1>(res.template subspan<0,N/2>());
FNT<Nbit-1>(res.template subspan<N/2,N/2>());
}
}

template <unsigned int Nbit>
void IFNT(const std::span<int64_t, 1u << Nbit> res)
{
if constexpr (Nbit == 1){
const int64_t temp = res[0];
res[0] += res[1];
if(res[0] >= P) res[0] -= P;
res[1] = temp - res[1];
if(res[1] < 0) res[1] += P;
}else{
constexpr unsigned int N = 1u << Nbit;
IFNT<Nbit-1>(res.template subspan<0,N/2>());
IFNT<Nbit-1>(res.template subspan<N/2,N/2>());
constexpr unsigned int stride = 1u << (Kbit+1 - Nbit);
for(unsigned int i = 0; i < N/2; i++){
if(i!=0) res[i+N/2] = ModLshift(res[i+N/2],(N-i)*stride);
const int64_t temp = res[i]+res[i+N/2];
res[i+N/2] = res[i]-res[i+N/2]; //Part of twiddle factor
if(res[i+N/2] < 0) res[i+N/2] += P;
res[i] = temp >= P ? temp - P : temp;
}
}
}


template <unsigned int Nbit>
void TwistFNT(std::array<int64_t, 1u << (Nbit+1)> &res, const std::array<int64_t, 1u << Nbit> &a)
{
constexpr unsigned int formersizebit = (Nbit + 1)/2;
static_assert(formersizebit <= Kbit, "sizebit must be less than or equal to Kbit");
constexpr unsigned int formersize = 1u << formersizebit;
constexpr unsigned int latersizebit = (Nbit + 1) - formersizebit;
constexpr unsigned int latersize = 1u << latersizebit;
constexpr unsigned int formerrbit = (Kbit+1) - formersizebit;
constexpr unsigned int laterrbit = (Kbit+1) - latersizebit;
//Former
for(unsigned int i = 0; i < latersize/2; i++){
std::array<int64_t, formersize> temp;
for(unsigned int j = 0; j < formersize; j++)
temp[j] = ModLshift(a[j*(latersize/2) + i],(j*(latersize/2) + i)<<(formerrbit-1));
FNT<formersizebit>(std::span{temp});
for(unsigned int j = 0; j < formersize; j++)
res[j*latersize + i] = temp[j];
}
//Later
for(unsigned int i = 0; i < formersize; i++)
FNT<latersizebit>(std::span{res}.subspan(i,latersize));
}
}
5 changes: 5 additions & 0 deletions include/raintt.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,11 @@ constexpr T ipow(T num, unsigned int pow)
: pow == 0 ? 1
: num * ipow(num, pow - 1);
}
#ifdef USE_COMPRESS
constexpr uint min_wordbits = 27;
#else
constexpr uint min_wordbits = 31;
#endif

#ifdef __clang__
// Currently _BigInt is only implemented in clang
Expand Down
47 changes: 47 additions & 0 deletions test/fnt.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
#include<fnt.hpp>
#include<random>
#include<array>
#include<iostream>
#include<cassert>

int main(){
constexpr uint32_t num_test = 1000;
constexpr unsigned int Nbit = 6;
constexpr unsigned int N = 1u << Nbit;

std::random_device seed_gen;
std::default_random_engine engine(seed_gen());
std::uniform_int_distribution<int64_t> Pdist(0, FNTpp::P);

std::cout<< "Start ModLShift Test"<< std::endl;
for(int test = 0; test < num_test; test++){
const int64_t a = Pdist(engine);
const uint shift = std::uniform_int_distribution<uint>(0, 63)(engine);
const int64_t res = FNTpp::ModLshift(a, shift);
if(res != (static_cast<__int128_t>(a) << shift) % FNTpp::P)
std::cout << "a: " << a << " shift: " << shift << " res: " << res << " expected: " << static_cast<int64_t>((static_cast<__int128_t>(a) << shift) % FNTpp::P) << std::endl;
assert(res == (static_cast<__int128_t>(a) << shift) % FNTpp::P);
}
std::cout<< "Passed ModLShift"<< std::endl;

std::cout << "invN Test" << std::endl;
assert(1 == FNTpp::ModLshift(N, 2*FNTpp::K-Nbit));
std::cout << "Passed invN" << std::endl;

std::cout << "Start univariable FNT only test." << std::endl;
for(int test = 0; test < num_test; test++){
std::array<int64_t, N> a;
for(int i = 0; i < N; i++) a[i] = Pdist(engine);
std::array<int64_t, N> res;
res = a;
FNTpp::FNT<Nbit>(res);
FNTpp::IFNT<Nbit>(res);
FNTpp::MulInvN<Nbit>(res);
for(int i = 0; i < N; i++)
if(a[i] != res[i]) std::cout << "i: "<< i << " a: " << a[i] << " res: " << res[i] << std::endl;
for(int i = 0; i < N; i++) assert(a[i] == res[i]);
}
std::cout << "Univariable FNT only test Passed" << std::endl;

return 0;
}
4 changes: 2 additions & 2 deletions test/raintt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ int main()
(1ULL << (raintt::wordbits - 1 - 1))) >>
(raintt::wordbits - 1);
assert(std::abs(static_cast<int>(a - c)) <=
(1 << (32 - raintt::wordbits + 1)));
((1U << raintt::min_wordbits)/raintt::P)+1);
}
std::cout << "Modswitch Passed" << std::endl;
for (int test = 0; test < num_test; test++) {
Expand All @@ -215,7 +215,7 @@ int main()
// 4)std::cout<<res[i]<<":"<<a[i]<<std::endl;
for (int i = 0; i < TFHEpp::lvl1param::n; i++)
assert(std::abs(static_cast<int>(res[i] - a[i])) <=
(1 << (32 - raintt::wordbits + 1)));
((1U << raintt::min_wordbits)/raintt::P)+1);
}
std::cout << "NTT with modswitch Passed" << std::endl;

Expand Down

0 comments on commit a66369c

Please sign in to comment.