Clean up pack reordering #5782

125 changes: 76 additions & 49 deletions src/nnue/nnue_feature_transformer.h
@@ -26,6 +26,7 @@
 #include <cstdint>
 #include <cstring>
 #include <iosfwd>
+#include <type_traits>
 #include <utility>
 
 #include "../position.h"
@@ -146,6 +147,72 @@ using psqt_vec_t = int32x4_t;
 #endif
 
 
+class Packing {
+   private:
+    // Store the order by which 128-bit blocks of 1024-bit data must
+    // be permuted so that calling packus on adjacent vectors of 16-bit
+    // integers loaded from the data results in the pre-permutation order.
+    static constexpr auto packus_epi16_order = []() -> std::array<std::size_t, 8> {
+#if defined(USE_AVX512)
+        // _mm512_packus_epi16 after permutation:
+        // |   0   |   2   |   4   |   6   |  // Vector 0
+        // |   1   |   3   |   5   |   7   |  // Vector 1
+        // | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 |  // Packed Result
+        return {0, 2, 4, 6, 1, 3, 5, 7};
+#elif defined(USE_AVX2)
+        // _mm256_packus_epi16 after permutation:
+        // |   0   |   2   |  |   4   |   6   |  // Vector 0, 2
+        // |   1   |   3   |  |   5   |   7   |  // Vector 1, 3
+        // | 0 | 1 | 2 | 3 |  | 4 | 5 | 6 | 7 |  // Packed Result
+        return {0, 2, 1, 3, 4, 6, 5, 7};
+#else
+        return {0, 1, 2, 3, 4, 5, 6, 7};
+#endif
+    }();
+
+    // The accumulator values are 16 bits wide, so there are 8
+    // values packed in every 128 bits.
+    static constexpr std::size_t epi16_chunk_size = 8;
+
+   public:
+    static constexpr std::size_t num_elements = epi16_chunk_size * packus_epi16_order.size();
+
+    static constexpr auto permute_for_packus_epi16 = [](auto* const values) {
+        std::array<std::remove_pointer_t<decltype(values)>,
+                   epi16_chunk_size * packus_epi16_order.size()>
+          buffer;
+
+        for (std::size_t i = 0; i < packus_epi16_order.size(); i++)
+        {
+            auto* const buffer_chunk = &buffer[i * epi16_chunk_size];
+            auto* const value_chunk  = &values[packus_epi16_order[i] * epi16_chunk_size];
+
+            std::copy(value_chunk, value_chunk + epi16_chunk_size, buffer_chunk);
+        }
+
+        std::copy(std::begin(buffer), std::end(buffer), values);
+    };
+
+    static constexpr auto unpermute_for_packus_epi16 = [](auto* const values) {
+        std::array<std::remove_pointer_t<decltype(values)>,
+                   epi16_chunk_size * packus_epi16_order.size()>
+          buffer;
+
+        for (std::size_t i = 0; i < packus_epi16_order.size(); i++)
+        {
+            auto* const buffer_chunk = &buffer[packus_epi16_order[i] * epi16_chunk_size];
+            auto* const value_chunk  = &values[i * epi16_chunk_size];
+
+            std::copy(value_chunk, value_chunk + epi16_chunk_size, buffer_chunk);
+        }
+
+        std::copy(std::begin(buffer), std::end(buffer), values);
+    };
+
+    ~Packing() = delete;
+};
+
+
 // Compute optimal SIMD register count for feature transformer accumulation.
 template<IndexType TransformedFeatureWidth, IndexType HalfDimensions>
 class SIMDTiling {
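
A quick way to sanity-check the AVX2 ordering above: the following standalone sketch (illustrative only, not part of this patch) tags 64 values with their own index, applies the same block permutation, and then models the lane interleaving of _mm256_packus_epi16, where each 128-bit result lane takes eight values from the first operand followed by eight from the second. Saturation and the 16-bit to 8-bit narrowing are ignored; only the block ordering matters here. The final assert confirms the packed output lands back in the original order.

// Standalone sketch: why order {0, 2, 1, 3, 4, 6, 5, 7} round-trips on AVX2.
#include <array>
#include <cassert>
#include <cstddef>

int main() {
    constexpr std::size_t ChunkSize = 8;  // 16-bit values per 128-bit block
    constexpr std::array<std::size_t, 8> order = {0, 2, 1, 3, 4, 6, 5, 7};

    // 64 values = 1024 bits of 16-bit data, each tagged with its own index.
    std::array<int, 64> values{};
    for (std::size_t i = 0; i < values.size(); i++)
        values[i] = int(i);

    // Same block permutation as Packing::permute_for_packus_epi16.
    std::array<int, 64> stored{};
    for (std::size_t i = 0; i < order.size(); i++)
        for (std::size_t k = 0; k < ChunkSize; k++)
            stored[i * ChunkSize + k] = values[order[i] * ChunkSize + k];

    // Model _mm256_packus_epi16(a, b) on two adjacent 256-bit vectors:
    // each 128-bit result lane takes a's lane first, then b's lane.
    std::array<int, 64> packed{};
    for (std::size_t pair = 0; pair < 2; pair++)
    {
        const int* a = &stored[pair * 32];       // vector 0 (or 2)
        const int* b = &stored[pair * 32 + 16];  // vector 1 (or 3)
        int*       r = &packed[pair * 32];
        for (std::size_t k = 0; k < ChunkSize; k++)
        {
            r[0 * ChunkSize + k] = a[0 * ChunkSize + k];  // a, low lane
            r[1 * ChunkSize + k] = b[0 * ChunkSize + k];  // b, low lane
            r[2 * ChunkSize + k] = a[1 * ChunkSize + k];  // a, high lane
            r[3 * ChunkSize + k] = b[1 * ChunkSize + k];  // b, high lane
        }
    }

    // The pack undoes the permutation: values are back in index order.
    for (std::size_t i = 0; i < packed.size(); i++)
        assert(packed[i] == int(i));
}
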
@@ -228,57 +295,17 @@ class FeatureTransformer {
         return FeatureSet::HashValue ^ (OutputDimensions * 2);
     }
 
-    static constexpr void order_packs([[maybe_unused]] uint64_t* v) {
-#if defined(USE_AVX512)  // _mm512_packs_epi16 ordering
-        uint64_t tmp0 = v[2], tmp1 = v[3];
-        v[2] = v[8], v[3] = v[9];
-        v[8] = v[4], v[9] = v[5];
-        v[4] = tmp0, v[5] = tmp1;
-        tmp0 = v[6], tmp1 = v[7];
-        v[6] = v[10], v[7] = v[11];
-        v[10] = v[12], v[11] = v[13];
-        v[12] = tmp0, v[13] = tmp1;
-#elif defined(USE_AVX2)  // _mm256_packs_epi16 ordering
-        std::swap(v[2], v[4]);
-        std::swap(v[3], v[5]);
-#endif
-    }
-
-    static constexpr void inverse_order_packs([[maybe_unused]] uint64_t* v) {
-#if defined(USE_AVX512)  // Inverse _mm512_packs_epi16 ordering
-        uint64_t tmp0 = v[2], tmp1 = v[3];
-        v[2] = v[4], v[3] = v[5];
-        v[4] = v[8], v[5] = v[9];
-        v[8] = tmp0, v[9] = tmp1;
-        tmp0 = v[6], tmp1 = v[7];
-        v[6] = v[12], v[7] = v[13];
-        v[12] = v[10], v[13] = v[11];
-        v[10] = tmp0, v[11] = tmp1;
-#elif defined(USE_AVX2)  // Inverse _mm256_packs_epi16 ordering
-        std::swap(v[2], v[4]);
-        std::swap(v[3], v[5]);
-#endif
-    }
-
-    void permute_weights([[maybe_unused]] void (*order_fn)(uint64_t*)) {
-#if defined(USE_AVX2)
-    #if defined(USE_AVX512)
-        constexpr IndexType di = 16;
-    #else
-        constexpr IndexType di = 8;
-    #endif
-        uint64_t* b = reinterpret_cast<uint64_t*>(&biases[0]);
-        for (IndexType i = 0; i < HalfDimensions * sizeof(BiasType) / sizeof(uint64_t); i += di)
-            order_fn(&b[i]);
+    template<typename Function>
+    void permute_weights(Function order_fn) {
+        for (IndexType i = 0; i < HalfDimensions; i += Packing::num_elements)
+            order_fn(&biases[i]);
 
         for (IndexType j = 0; j < InputDimensions; ++j)
         {
-            uint64_t* w = reinterpret_cast<uint64_t*>(&weights[j * HalfDimensions]);
-            for (IndexType i = 0; i < HalfDimensions * sizeof(WeightType) / sizeof(uint64_t);
-                 i += di)
+            auto* w = &weights[j * HalfDimensions];
+            for (IndexType i = 0; i < HalfDimensions; i += Packing::num_elements)
                 order_fn(&w[i]);
         }
-#endif
     }
 
     inline void scale_weights(bool read) {
@@ -300,22 +327,22 @@
         read_leb_128<WeightType>(stream, weights, HalfDimensions * InputDimensions);
         read_leb_128<PSQTWeightType>(stream, psqtWeights, PSQTBuckets * InputDimensions);
 
-        permute_weights(inverse_order_packs);
+        permute_weights(Packing::permute_for_packus_epi16);
         scale_weights(true);
         return !stream.fail();
     }
 
     // Write network parameters
     bool write_parameters(std::ostream& stream) {
 
-        permute_weights(order_packs);
+        permute_weights(Packing::unpermute_for_packus_epi16);
         scale_weights(false);
 
         write_leb_128<BiasType>(stream, biases, HalfDimensions);
         write_leb_128<WeightType>(stream, weights, HalfDimensions * InputDimensions);
         write_leb_128<PSQTWeightType>(stream, psqtWeights, PSQTBuckets * InputDimensions);
 
-        permute_weights(inverse_order_packs);
+        permute_weights(Packing::permute_for_packus_epi16);
         scale_weights(true);
         return !stream.fail();
     }
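
Note the symmetry here: read_parameters() permutes right after loading, so the file keeps the canonical value order while memory holds the packus-friendly order, and write_parameters() unpermutes before serializing, then re-permutes afterwards. That round trip is only sound because the two lambdas are exact inverses. A minimal check, assuming this header is included and Packing is reachable from the current namespace (the helper name is hypothetical):

#include <array>
#include <cassert>
#include <cstddef>
#include <cstdint>

void packing_roundtrip_check() {
    // One 1024-bit group of 16-bit values, tagged with their own index.
    std::array<std::int16_t, Packing::num_elements> v{};
    for (std::size_t i = 0; i < v.size(); i++)
        v[i] = static_cast<std::int16_t>(i);

    Packing::permute_for_packus_epi16(v.data());
    Packing::unpermute_for_packus_epi16(v.data());

    // unpermute composed with permute is the identity.
    for (std::size_t i = 0; i < v.size(); i++)
        assert(v[i] == static_cast<std::int16_t>(i));
}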