#pragma once

/*
* Blake2b AVX2 single-round function, based on the libsodium implementation:
* https://github.com/jedisct1/libsodium/blob/1.0.16/src/libsodium/crypto_generichash/blake2b/ref/blake2b-compress-avx2.h
*/

#include "avx2.hpp"

namespace modernRX::blake2b {
    // Exception to the rule of not using macros.
    // Inlining this function is crucial for performance, and I wasn't able to force the
    // compiler to inline it with a template function alone.
    // Making a call for each round severely slows down the algorithm (even 10-20x).
    //
    // vshuffleepi32 performs right rotation by 32 bits.
    // vshuffleepi8 performs right rotation by 24 or 16 bits, according to the mask.
    // vpermuteepi64 rotates the 64-bit lanes within a vector by 64, 128 or 192 bits.
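    //
    // A minimal sketch of the rot24/rot16 arguments (an assumption for illustration; the
    // actual constants live at the call site). As in libsodium, each 64-bit lane is rotated
    // right by shuffling its bytes: byte i of the result comes from byte (i + 3) % 8 for
    // rotr64 by 24, and from byte (i + 2) % 8 for rotr64 by 16:
    //
    //   const auto rot24{ _mm256_setr_epi8(3, 4, 5, 6, 7, 0, 1, 2, 11, 12, 13, 14, 15, 8, 9, 10,
    //                                      3, 4, 5, 6, 7, 0, 1, 2, 11, 12, 13, 14, 15, 8, 9, 10) };
    //   const auto rot16{ _mm256_setr_epi8(2, 3, 4, 5, 6, 7, 0, 1, 10, 11, 12, 13, 14, 15, 8, 9,
    //                                      2, 3, 4, 5, 6, 7, 0, 1, 10, 11, 12, 13, 14, 15, 8, 9) };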
#define ROUND(round, msg, v1, v2, v3, v4, rot24, rot16) do {                        \
/*G1V1*/                                                                            \
        v1 = intrinsics::avx2::vadd<uint64_t>(v1, m);                               \
        v1 = intrinsics::avx2::vadd<uint64_t>(v1, v2);                              \
        m = msgPermutation<1, round>(msg);                                          \
        v4 = intrinsics::avx2::vxor<uint64_t>(v4, v1);                              \
        v4 = intrinsics::avx2::vshuffleepi32<uint64_t, 0b10'11'00'01>(v4);          \
        v3 = intrinsics::avx2::vadd<uint64_t>(v3, v4);                              \
        v2 = intrinsics::avx2::vxor<uint64_t>(v2, v3);                              \
        v2 = intrinsics::avx2::vshuffleepi8<uint64_t>(v2, rot24);                   \
/*G2V1*/                                                                            \
        v1 = intrinsics::avx2::vadd<uint64_t>(v1, m);                               \
        v1 = intrinsics::avx2::vadd<uint64_t>(v1, v2);                              \
        m = msgPermutation<2, round>(msg);                                          \
        v4 = intrinsics::avx2::vxor<uint64_t>(v4, v1);                              \
        v4 = intrinsics::avx2::vshuffleepi8<uint64_t>(v4, rot16);                   \
        v3 = intrinsics::avx2::vadd<uint64_t>(v3, v4);                              \
        v2 = intrinsics::avx2::vxor<uint64_t>(v2, v3);                              \
        v2 = intrinsics::avx2::vrorpi64<uint64_t, 63>(v2);                          \
/*DIAG_V1*/                                                                         \
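/* v2, v3, v4: rotate lanes left by 1, 2, 3 to turn columns into diagonals */       \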
        v2 = intrinsics::avx2::vpermuteepi64<uint64_t, 0b00'11'10'01>(v2);          \
        v4 = intrinsics::avx2::vpermuteepi64<uint64_t, 0b10'01'00'11>(v4);          \
        v3 = intrinsics::avx2::vpermuteepi64<uint64_t, 0b01'00'11'10>(v3);          \
/*G1V1*/                                                                            \
        v1 = intrinsics::avx2::vadd<uint64_t>(v1, m);                               \
        v1 = intrinsics::avx2::vadd<uint64_t>(v1, v2);                              \
        m = msgPermutation<3, round>(msg);                                          \
        v4 = intrinsics::avx2::vxor<uint64_t>(v1, v4);                              \
        v4 = intrinsics::avx2::vshuffleepi32<uint64_t, 0b10'11'00'01>(v4);          \
        v3 = intrinsics::avx2::vadd<uint64_t>(v3, v4);                              \
        v2 = intrinsics::avx2::vxor<uint64_t>(v2, v3);                              \
        v2 = intrinsics::avx2::vshuffleepi8<uint64_t>(v2, rot24);                   \
/*G2V1*/                                                                            \
        v1 = intrinsics::avx2::vadd<uint64_t>(v1, m);                               \
        v1 = intrinsics::avx2::vadd<uint64_t>(v1, v2);                              \
        m = msgPermutation<0, (round + 1) % 10>(msg);                               \
        v4 = intrinsics::avx2::vxor<uint64_t>(v4, v1);                              \
        v4 = intrinsics::avx2::vshuffleepi8<uint64_t>(v4, rot16);                   \
        v3 = intrinsics::avx2::vadd<uint64_t>(v3, v4);                              \
        v2 = intrinsics::avx2::vxor<uint64_t>(v2, v3);                              \
        v2 = intrinsics::avx2::vrorpi64<uint64_t, 63>(v2);                          \
/*UNDIAG_V1*/                                                                       \
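/* inverse lane rotation restores the column layout */                              \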
        v2 = intrinsics::avx2::vpermuteepi64<uint64_t, 0b10'01'00'11>(v2);          \
        v4 = intrinsics::avx2::vpermuteepi64<uint64_t, 0b00'11'10'01>(v4);          \
        v3 = intrinsics::avx2::vpermuteepi64<uint64_t, 0b01'00'11'10>(v3);          \
} while(0)
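
    // Illustrative usage (a sketch under the assumptions above, not part of this header):
    // a compress function keeps the Blake2b state in four ymm registers, prefetches the
    // first message permutation, then runs all 12 rounds back-to-back. Because the macro
    // prefetches msgPermutation<0, (round + 1) % 10>, rounds 10 and 11 reuse the sigma
    // rows of rounds 0 and 1:
    //
    //   auto m{ msgPermutation<0, 0>(msg) };
    //   ROUND(0, msg, v1, v2, v3, v4, rot24, rot16);
    //   ROUND(1, msg, v1, v2, v3, v4, rot24, rot16);
    //   // ... rounds 2-9 ...
    //   ROUND(0, msg, v1, v2, v3, v4, rot24, rot16); // round 10
    //   ROUND(1, msg, v1, v2, v3, v4, rot24, rot16); // round 11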


    // Permutes the message words for the AVX2-enhanced version.
    // It's important that this function is always inlined, otherwise it may hurt performance.
    // There's no portable way to force inlining, so it would be safer to use a macro, but the
    // compiler seems to be smart enough to inline it anyway.
    // An additional check of the generated binary would be nice, to ensure that this function
    // was in fact inlined.
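    //
    // Worked example for Round 0 (assuming, as in libsodium's AVX2 compress, that each m[i]
    // holds message words {w(2i), w(2i+1)} broadcast to both 128-bit lanes):
    //   N == 0: unpacklo(m[0], m[1]) = {w0, w2, w0, w2}, unpacklo(m[2], m[3]) = {w4, w6, w4, w6},
    //           and the blend keeps the low lane of the first and the high lane of the second,
    //           yielding {w0, w2, w4, w6} -- the first G inputs of the four columns.
    //   N == 1: the unpackhi variants yield {w1, w3, w5, w7}, and N == 2 / N == 3 produce
    //           {w8, w10, w12, w14} / {w9, w11, w13, w15}, matching sigma row 0.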
    template <uint32_t N, uint32_t Round>
    [[nodiscard]] constexpr intrinsics::ymm<uint64_t> msgPermutation(const_span<intrinsics::ymm<uint64_t>, 8> m) noexcept {
        intrinsics::ymm<uint64_t> t0, t1;

        if constexpr (N == 0) {
            if constexpr (Round == 0) {
                t0 = intrinsics::avx2::vunpackloepi64<uint64_t>(m[0], m[1]);
                t1 = intrinsics::avx2::vunpackloepi64<uint64_t>(m[2], m[3]);
                return intrinsics::avx2::vblendepi32<uint64_t, 0b11'11'00'00>(t0, t1);
            } else if constexpr (Round == 1) {
                t0 = intrinsics::avx2::vunpackloepi64<uint64_t>(m[7], m[2]);
                t1 = intrinsics::avx2::vunpackhiepi64<uint64_t>(m[4], m[6]);
                return intrinsics::avx2::vblendepi32<uint64_t, 0b11'11'00'00>(t0, t1);
            } else if constexpr (Round == 2) {
                t0 = intrinsics::avx2::valignrepi8<uint64_t, 8>(m[6], m[5]);
                t1 = intrinsics::avx2::vunpackhiepi64<uint64_t>(m[2], m[7]);
                return intrinsics::avx2::vblendepi32<uint64_t, 0b11'11'00'00>(t0, t1);
            } else if constexpr (Round == 3) {
                t0 = intrinsics::avx2::vunpackhiepi64<uint64_t>(m[3], m[1]);
                t1 = intrinsics::avx2::vunpackhiepi64<uint64_t>(m[6], m[5]);
                return intrinsics::avx2::vblendepi32<uint64_t, 0b11'11'00'00>(t0, t1);
            } else if constexpr (Round == 4) {
                t0 = intrinsics::avx2::vunpackhiepi64<uint64_t>(m[4], m[2]);
                t1 = intrinsics::avx2::vunpackloepi64<uint64_t>(m[1], m[5]);
                return intrinsics::avx2::vblendepi32<uint64_t, 0b11'11'00'00>(t0, t1);
            } else if constexpr (Round == 5) {
                t0 = intrinsics::avx2::vunpackloepi64<uint64_t>(m[1], m[3]);
                t1 = intrinsics::avx2::vunpackloepi64<uint64_t>(m[0], m[4]);
                return intrinsics::avx2::vblendepi32<uint64_t, 0b11'11'00'00>(t0, t1);
            } else if constexpr (Round == 6) {
                t0 = intrinsics::avx2::vblendepi32<uint64_t, 0b00'11'00'11>(m[0], m[6]);
                t1 = intrinsics::avx2::vunpackloepi64<uint64_t>(m[7], m[2]);
                return intrinsics::avx2::vblendepi32<uint64_t, 0b11'11'00'00>(t0, t1);
            } else if constexpr (Round == 7) {
                t0 = intrinsics::avx2::vunpackhiepi64<uint64_t>(m[6], m[3]);
                t1 = intrinsics::avx2::vblendepi32<uint64_t, 0b00'11'00'11>(m[1], m[6]);
                return intrinsics::avx2::vblendepi32<uint64_t, 0b11'11'00'00>(t0, t1);
            } else if constexpr (Round == 8) {
                t0 = intrinsics::avx2::vunpackloepi64<uint64_t>(m[3], m[7]);
                t1 = intrinsics::avx2::valignrepi8<uint64_t, 8>(m[0], m[5]);
                return intrinsics::avx2::vblendepi32<uint64_t, 0b11'11'00'00>(t0, t1);
            } else if constexpr (Round == 9) {
                t0 = intrinsics::avx2::vunpackloepi64<uint64_t>(m[5], m[4]);
                t1 = intrinsics::avx2::vunpackhiepi64<uint64_t>(m[3], m[0]);
                return intrinsics::avx2::vblendepi32<uint64_t, 0b11'11'00'00>(t0, t1);
            } else {
                static_assert(!Round, "Invalid round number");
            }
        } else if constexpr (N == 1) {
            if constexpr (Round == 0) {
                t0 = intrinsics::avx2::vunpackhiepi64<uint64_t>(m[0], m[1]);
                t1 = intrinsics::avx2::vunpackhiepi64<uint64_t>(m[2], m[3]);
                return intrinsics::avx2::vblendepi32<uint64_t, 0b11'11'00'00>(t0, t1);
            } else if constexpr (Round == 1) {
                t0 = intrinsics::avx2::vunpackloepi64<uint64_t>(m[5], m[4]);
                t1 = intrinsics::avx2::valignrepi8<uint64_t, 8>(m[3], m[7]);
                return intrinsics::avx2::vblendepi32<uint64_t, 0b11'11'00'00>(t0, t1);
            } else if constexpr (Round == 2) {
                t0 = intrinsics::avx2::vunpackloepi64<uint64_t>(m[4], m[0]);
                t1 = intrinsics::avx2::vblendepi32<uint64_t, 0b00'11'00'11>(m[6], m[1]);
                return intrinsics::avx2::vblendepi32<uint64_t, 0b11'11'00'00>(t0, t1);
            } else if constexpr (Round == 3) {
                t0 = intrinsics::avx2::vunpackhiepi64<uint64_t>(m[4], m[0]);
                t1 = intrinsics::avx2::vunpackloepi64<uint64_t>(m[6], m[7]);
                return intrinsics::avx2::vblendepi32<uint64_t, 0b11'11'00'00>(t0, t1);
            } else if constexpr (Round == 4) {
                t0 = intrinsics::avx2::vblendepi32<uint64_t, 0b00'11'00'11>(m[3], m[0]);
                t1 = intrinsics::avx2::vblendepi32<uint64_t, 0b00'11'00'11>(m[7], m[2]);
                return intrinsics::avx2::vblendepi32<uint64_t, 0b11'11'00'00>(t0, t1);
            } else if constexpr (Round == 5) {
                t0 = intrinsics::avx2::vunpackloepi64<uint64_t>(m[6], m[5]);
                t1 = intrinsics::avx2::vunpackhiepi64<uint64_t>(m[5], m[1]);
                return intrinsics::avx2::vblendepi32<uint64_t, 0b11'11'00'00>(t0, t1);
            } else if constexpr (Round == 6) {
                t0 = intrinsics::avx2::vunpackhiepi64<uint64_t>(m[2], m[7]);
                t1 = intrinsics::avx2::valignrepi8<uint64_t, 8>(m[5], m[6]);
                return intrinsics::avx2::vblendepi32<uint64_t, 0b11'11'00'00>(t0, t1);
            } else if constexpr (Round == 7) {
                t0 = intrinsics::avx2::valignrepi8<uint64_t, 8>(m[7], m[5]);
                t1 = intrinsics::avx2::vunpackhiepi64<uint64_t>(m[0], m[4]);
                return intrinsics::avx2::vblendepi32<uint64_t, 0b11'11'00'00>(t0, t1);
            } else if constexpr (Round == 8) {
                t0 = intrinsics::avx2::vunpackhiepi64<uint64_t>(m[7], m[4]);
                t1 = intrinsics::avx2::valignrepi8<uint64_t, 8>(m[4], m[1]);
                return intrinsics::avx2::vblendepi32<uint64_t, 0b11'11'00'00>(t0, t1);
            } else if constexpr (Round == 9) {
                t0 = intrinsics::avx2::vunpackloepi64<uint64_t>(m[1], m[2]);
                t1 = intrinsics::avx2::vblendepi32<uint64_t, 0b00'11'00'11>(m[2], m[3]);
                return intrinsics::avx2::vblendepi32<uint64_t, 0b11'11'00'00>(t0, t1);
            } else {
                static_assert(!Round, "Invalid round number");
            }
        } else if constexpr (N == 2) {
            if constexpr (Round == 0) {
                t0 = intrinsics::avx2::vunpackloepi64<uint64_t>(m[4], m[5]);
                t1 = intrinsics::avx2::vunpackloepi64<uint64_t>(m[6], m[7]);
                return intrinsics::avx2::vblendepi32<uint64_t, 0b11'11'00'00>(t0, t1);
            } else if constexpr (Round == 1) {
                t0 = intrinsics::avx2::vpermuteepi64<uint64_t, 0b10'11'00'01>(m[0]);
                t1 = intrinsics::avx2::vunpackhiepi64<uint64_t>(m[5], m[2]);
                return intrinsics::avx2::vblendepi32<uint64_t, 0b11'11'00'00>(t0, t1);
            } else if constexpr (Round == 2) {
                t0 = intrinsics::avx2::vblendepi32<uint64_t, 0b00'11'00'11>(m[1], m[5]);
                t1 = intrinsics::avx2::vunpackhiepi64<uint64_t>(m[3], m[4]);
                return intrinsics::avx2::vblendepi32<uint64_t, 0b11'11'00'00>(t0, t1);
            } else if constexpr (Round == 3) {
                t0 = intrinsics::avx2::vblendepi32<uint64_t, 0b00'11'00'11>(m[2], m[1]);
                t1 = intrinsics::avx2::vblendepi32<uint64_t, 0b00'11'00'11>(m[7], m[2]);
                return intrinsics::avx2::vblendepi32<uint64_t, 0b11'11'00'00>(t0, t1);
            } else if constexpr (Round == 4) {
                t0 = intrinsics::avx2::vblendepi32<uint64_t, 0b00'11'00'11>(m[5], m[7]);
                t1 = intrinsics::avx2::vblendepi32<uint64_t, 0b00'11'00'11>(m[1], m[3]);
                return intrinsics::avx2::vblendepi32<uint64_t, 0b11'11'00'00>(t0, t1);
            } else if constexpr (Round == 5) {
                t0 = intrinsics::avx2::vblendepi32<uint64_t, 0b00'11'00'11>(m[3], m[2]);
                t1 = intrinsics::avx2::vunpackhiepi64<uint64_t>(m[7], m[0]);
                return intrinsics::avx2::vblendepi32<uint64_t, 0b11'11'00'00>(t0, t1);
            } else if constexpr (Round == 6) {
                t0 = intrinsics::avx2::vunpackloepi64<uint64_t>(m[0], m[3]);
                t1 = intrinsics::avx2::vpermuteepi64<uint64_t, 0b10'11'00'01>(m[4]);
                return intrinsics::avx2::vblendepi32<uint64_t, 0b11'11'00'00>(t0, t1);
            } else if constexpr (Round == 7) {
                t0 = intrinsics::avx2::vunpackhiepi64<uint64_t>(m[2], m[7]);
                t1 = intrinsics::avx2::vunpackloepi64<uint64_t>(m[4], m[1]);
                return intrinsics::avx2::vblendepi32<uint64_t, 0b11'11'00'00>(t0, t1);
            } else if constexpr (Round == 8) {
                t0 = m[6];
                t1 = intrinsics::avx2::valignrepi8<uint64_t, 8>(m[5], m[0]);
                return intrinsics::avx2::vblendepi32<uint64_t, 0b11'11'00'00>(t0, t1);
            } else if constexpr (Round == 9) {
                t0 = intrinsics::avx2::vunpackhiepi64<uint64_t>(m[7], m[4]);
                t1 = intrinsics::avx2::vunpackhiepi64<uint64_t>(m[1], m[6]);
                return intrinsics::avx2::vblendepi32<uint64_t, 0b11'11'00'00>(t0, t1);
            } else {
                static_assert(!Round, "Invalid round number");
            }
        } else if constexpr (N == 3) {
            if constexpr (Round == 0) {
                t0 = intrinsics::avx2::vunpackhiepi64<uint64_t>(m[4], m[5]);
                t1 = intrinsics::avx2::vunpackhiepi64<uint64_t>(m[6], m[7]);
                return intrinsics::avx2::vblendepi32<uint64_t, 0b11'11'00'00>(t0, t1);
            } else if constexpr (Round == 1) {
                t0 = intrinsics::avx2::vunpackloepi64<uint64_t>(m[6], m[1]);
                t1 = intrinsics::avx2::vunpackhiepi64<uint64_t>(m[3], m[1]);
                return intrinsics::avx2::vblendepi32<uint64_t, 0b11'11'00'00>(t0, t1);
            } else if constexpr (Round == 2) {
                t0 = intrinsics::avx2::vunpackloepi64<uint64_t>(m[7], m[3]);
                t1 = intrinsics::avx2::valignrepi8<uint64_t, 8>(m[2], m[0]);
                return intrinsics::avx2::vblendepi32<uint64_t, 0b11'11'00'00>(t0, t1);
            } else if constexpr (Round == 3) {
                t0 = intrinsics::avx2::vunpackloepi64<uint64_t>(m[3], m[5]);
                t1 = intrinsics::avx2::vunpackloepi64<uint64_t>(m[0], m[4]);
                return intrinsics::avx2::vblendepi32<uint64_t, 0b11'11'00'00>(t0, t1);
            } else if constexpr (Round == 4) {
                t0 = intrinsics::avx2::valignrepi8<uint64_t, 8>(m[6], m[0]);
                t1 = intrinsics::avx2::vblendepi32<uint64_t, 0b00'11'00'11>(m[6], m[4]);
                return intrinsics::avx2::vblendepi32<uint64_t, 0b11'11'00'00>(t0, t1);
            } else if constexpr (Round == 5) {
                t0 = intrinsics::avx2::vunpackhiepi64<uint64_t>(m[6], m[2]);
                t1 = intrinsics::avx2::vblendepi32<uint64_t, 0b00'11'00'11>(m[4], m[7]);
                return intrinsics::avx2::vblendepi32<uint64_t, 0b11'11'00'00>(t0, t1);
            } else if constexpr (Round == 6) {
                t0 = intrinsics::avx2::vunpackhiepi64<uint64_t>(m[3], m[1]);
                t1 = intrinsics::avx2::vblendepi32<uint64_t, 0b00'11'00'11>(m[5], m[1]);
                return intrinsics::avx2::vblendepi32<uint64_t, 0b11'11'00'00>(t0, t1);
            } else if constexpr (Round == 7) {
                t0 = intrinsics::avx2::vunpackloepi64<uint64_t>(m[0], m[2]);
                t1 = intrinsics::avx2::vunpackloepi64<uint64_t>(m[3], m[5]);
                return intrinsics::avx2::vblendepi32<uint64_t, 0b11'11'00'00>(t0, t1);
            } else if constexpr (Round == 8) {
                t0 = intrinsics::avx2::vblendepi32<uint64_t, 0b00'11'00'11>(m[3], m[1]);
                t1 = m[2];
                return intrinsics::avx2::vblendepi32<uint64_t, 0b11'11'00'00>(t0, t1);
            } else if constexpr (Round == 9) {
                t0 = intrinsics::avx2::valignrepi8<uint64_t, 8>(m[7], m[5]);
                t1 = intrinsics::avx2::vunpackloepi64<uint64_t>(m[6], m[0]);
                return intrinsics::avx2::vblendepi32<uint64_t, 0b11'11'00'00>(t0, t1);
            } else {
                static_assert(!Round, "Invalid round number");
            }
        } else {
            static_assert(!N, "Invalid template parameter N");
        }
    }
} // namespace modernRX::blake2b
