Skip to content

Commit

Permalink
Avoid unintended use of AVX512
Browse files Browse the repository at this point in the history
  • Loading branch information
nindanaoto committed Oct 22, 2024
1 parent 1eb3d57 commit 21a7480
Showing 1 changed file with 106 additions and 10 deletions.
116 changes: 106 additions & 10 deletions thirdparties/spqlios/fft_processor_spqlios.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,36 +38,65 @@ FFT_Processor_Spqlios::FFT_Processor_Spqlios(const int32_t N) : _2N(2 * N), N(N)
}

/**
 * Converts the N unsigned 32-bit coefficients in `a` to doubles stored in
 * `res`, then runs the inverse FFT in place on `res`.
 *
 * @param res  Output buffer; receives the converted values and is then passed
 *             to ifft() together with tables_reverse.
 * @param a    Input coefficients (N values are read).
 *
 * The vectorised path is only compiled when USE_AVX512 is defined, because
 * vcvtudq2pd is an AVX-512 instruction; every other build uses the scalar loop.
 */
void FFT_Processor_Spqlios::execute_reverse_uint(double *res, const uint32_t *a) {
#ifdef USE_AVX512
    {
        double *dst = res;
        const uint32_t *ait = a;
        const uint32_t *aend = a + N;
        // NOTE(review): the loop body runs before the bound is checked and
        // consumes 8 elements per iteration, so this assumes N is a non-zero
        // multiple of 8 -- confirm against the constructor.
        __asm__ __volatile__ (
            "0:\n"
            "vmovupd (%1),%%ymm0\n"        // load 8 x uint32 (32 bytes)
            "vcvtudq2pd %%ymm0,%%zmm1\n"   // widen to 8 x double (64 bytes)
            // Store all 8 doubles. Storing only ymm1 would write 4 of them
            // while dst still advances by 64 bytes. The unaligned form is used
            // so that `res` does not have to be 64-byte aligned.
            "vmovupd %%zmm1,(%0)\n"
            "addq $32,%1\n"                // 8 inputs consumed
            "addq $64,%0\n"                // 8 outputs produced
            "cmpq %2,%1\n"
            "jb 0b\n"
            : "+r"(dst), "+r"(ait)
            : "r"(aend)
            : "%ymm0", "%zmm1", "cc", "memory"
        );
    }
#else
    for (int32_t i = 0; i < N; i++) res[i] = static_cast<double>(a[i]);
#endif
    ifft(tables_reverse, res);
}

void FFT_Processor_Spqlios::execute_reverse_int(double *res, const int32_t *a) {
//for (int32_t i=0; i<N; i++) real_inout_rev[i]=(double)a[i];
{
double *dst = res;
// double *dst = real_inout_rev;
const int32_t *ait = a;
const int32_t *aend = a + N;
#ifdef USE_AVX512
__asm__ __volatile__ (
"0:\n"
"vmovdqu32 (%1),%%zmm0\n" // Load 16 int32_t values from `ait` into zmm0
"vcvtdq2pd %%zmm0,%%zmm1\n" // Convert 16 int32_t values to 8 double-precision values
"vmovapd %%zmm1,(%0)\n" // Store the result (8 doubles) in `dst`
"addq $64,%1\n" // Increment `ait` by 64 bytes (16 int32_t values)
"addq $64,%0\n" // Increment `dst` by 64 bytes (8 double-precision values)
"cmpq %2,%1\n" // Compare `ait` with `aend`
"jb 0b\n" // Jump back if `ait < aend`
: "=r"(dst), "=r"(ait), "=r"(aend)
: "0"(dst), "1"(ait), "2"(aend)
: "%zmm0", "%zmm1", "memory"
);
#else
__asm__ __volatile__ (
"0:\n"
"vmovupd (%1),%%xmm0\n"
Expand All @@ -81,6 +110,7 @@ void FFT_Processor_Spqlios::execute_reverse_int(double *res, const int32_t *a) {
: "0"(dst), "1"(ait), "2"(aend)
: "%xmm0", "%ymm1", "memory"
);
#endif
}
ifft(tables_reverse, res);
}
Expand Down Expand Up @@ -110,8 +140,23 @@ void FFT_Processor_Spqlios::execute_direct_torus32(uint32_t *res, const double *
double *dst = real_inout_direct;
const double *sit = a;
const double *send = a + N;
//double __2sN = 2./N;
const double *bla = &_2sN;
#ifdef AVX512
__asm__ __volatile__ (
"vbroadcastsd (%3),%%zmm2\n" // Broadcast _2sN to zmm2
"1:\n"
"vmovupd (%1),%%zmm0\n" // Load 8 double-precision values from `sit` into zmm0
"vmulpd %%zmm2,%%zmm0,%%zmm0\n" // Multiply zmm0 by zmm2
"vmovupd %%zmm0,(%0)\n" // Store the result in `dst`
"addq $64,%1\n" // Increment `sit` by 64 bytes (8 doubles)
"addq $64,%0\n" // Increment `dst` by 64 bytes (8 doubles)
"cmpq %2,%1\n" // Compare `sit` with `send`
"jb 1b\n" // Jump if `sit` < `send`
: "=r"(dst), "=r"(sit), "=r"(send), "=r"(bla)
: "0"(dst), "1"(sit), "2"(send), "3"(bla)
: "%zmm0", "%zmm2", "memory"
);
#else
__asm__ __volatile__ (
"vbroadcastsd (%3),%%ymm2\n"
"1:\n"
Expand All @@ -126,6 +171,7 @@ void FFT_Processor_Spqlios::execute_direct_torus32(uint32_t *res, const double *
: "0"(dst), "1"(sit), "2"(send), "3"(bla)
: "%ymm0", "%ymm2", "memory"
);
#endif
}
fft(tables_direct, real_inout_direct);
// for (int32_t i = 0; i < N; i++) res[i] = uint32_t(int64_t(real_inout_direct[i]));
Expand All @@ -142,6 +188,22 @@ void FFT_Processor_Spqlios::execute_direct_torus32_q(uint32_t *res, const double
const double *send = a + N;
//double __2sN = 2./N;
const double *bla = &_2sN;
#ifdef USE_AVX512
__asm__ __volatile__ (
"vbroadcastsd (%3),%%zmm2\n" // Broadcast _2sN to zmm2
"1:\n"
"vmovupd (%1),%%zmm0\n" // Load 8 double-precision values from `sit` into zmm0
"vmulpd %%zmm2,%%zmm0,%%zmm0\n" // Multiply zmm0 by zmm2
"vmovupd %%zmm0,(%0)\n" // Store the result in `dst`
"addq $64,%1\n" // Increment `sit` by 64 bytes (8 doubles)
"addq $64,%0\n" // Increment `dst` by 64 bytes (8 doubles)
"cmpq %2,%1\n" // Compare `sit` with `send`
"jb 1b\n" // Jump if `sit` < `send`
: "=r"(dst), "=r"(sit), "=r"(send), "=r"(bla)
: "0"(dst), "1"(sit), "2"(send), "3"(bla)
: "%zmm0", "%zmm2", "memory"
);
#else
__asm__ __volatile__ (
"vbroadcastsd (%3),%%ymm2\n"
"1:\n"
Expand All @@ -156,6 +218,7 @@ void FFT_Processor_Spqlios::execute_direct_torus32_q(uint32_t *res, const double
: "0"(dst), "1"(sit), "2"(send), "3"(bla)
: "%ymm0", "%ymm2", "memory"
);
#endif
}
fft(tables_direct, real_inout_direct);
for (int32_t i = 0; i < N; i++) res[i] = uint32_t((int64_t(real_inout_direct[i])%q+q)%q);
Expand All @@ -169,8 +232,23 @@ void FFT_Processor_Spqlios::execute_direct_torus32_rescale(uint32_t *res, const
double *dst = real_inout_direct;
const double *sit = a;
const double *send = a + N;
//double __2sN = 2./N;
const double *bla = &_2sN;
#ifdef USE_AVX512
__asm__ __volatile__ (
"vbroadcastsd (%3),%%zmm2\n" // Broadcast _2sN to zmm2
"1:\n"
"vmovupd (%1),%%zmm0\n" // Load 8 double-precision values from `sit` into zmm0
"vmulpd %%zmm2,%%zmm0,%%zmm0\n" // Multiply zmm0 by zmm2
"vmovupd %%zmm0,(%0)\n" // Store the result in `dst`
"addq $64,%1\n" // Increment `sit` by 64 bytes (8 doubles)
"addq $64,%0\n" // Increment `dst` by 64 bytes (8 doubles)
"cmpq %2,%1\n" // Compare `sit` with `send`
"jb 1b\n" // Jump if `sit` < `send`
: "=r"(dst), "=r"(sit), "=r"(send), "=r"(bla)
: "0"(dst), "1"(sit), "2"(send), "3"(bla)
: "%zmm0", "%zmm2", "memory"
);
#else
__asm__ __volatile__ (
"vbroadcastsd (%3),%%ymm2\n"
"1:\n"
Expand All @@ -185,6 +263,7 @@ void FFT_Processor_Spqlios::execute_direct_torus32_rescale(uint32_t *res, const
: "0"(dst), "1"(sit), "2"(send), "3"(bla)
: "%ymm0", "%ymm2", "memory"
);
#endif
}
fft(tables_direct, real_inout_direct);
for (int32_t i = 0; i < N; i++) res[i] = static_cast<uint32_t>(int64_t(real_inout_direct[i]/Δ));
Expand All @@ -200,6 +279,22 @@ void FFT_Processor_Spqlios::execute_direct_torus64(uint64_t* res, const double*
const double* send = a+N;
//double __2sN = 2./N;
const double* bla = &_2sN;
#ifdef USE_AVX512
__asm__ __volatile__ (
"vbroadcastsd (%3),%%zmm2\n" // Broadcast 2sN to zmm2
"1:\n"
"vmovupd (%1),%%zmm0\n" // Load 8 double-precision floats from `sit` into zmm0
"vmulpd %%zmm2,%%zmm0,%%zmm0\n" // Multiply the vector by zmm2
"vmovapd %%zmm0,(%0)\n" // Store the result into `dst`
"addq $64,%1\n" // Increment `sit` by 64 (8 doubles * 8 bytes per double)
"addq $64,%0\n" // Increment `dst` by 64 (8 doubles * 8 bytes per double)
"cmpq %2,%1\n" // Compare `sit` with `send`
"jb 1b\n" // Jump back if not done
: "=r"(dst), "=r"(sit), "=r"(send), "=r"(bla)
: "0"(dst), "1"(sit), "2"(send), "3"(bla)
: "%zmm0", "%zmm2", "memory"
);
#else
__asm__ __volatile__ (
"vbroadcastsd (%3),%%ymm2\n"
"1:\n"
Expand All @@ -214,6 +309,7 @@ void FFT_Processor_Spqlios::execute_direct_torus64(uint64_t* res, const double*
: "0"(dst),"1"(sit),"2"(send),"3"(bla)
: "%ymm0","%ymm2","memory"
);
#endif
}
fft(tables_direct,real_inout_direct);
#ifdef USE_AVX512
Expand Down

0 comments on commit 21a7480

Please sign in to comment.