#include <array>
#include <bit>
#include <format>
#include "aes.hpp"
#include "aes1rhash.hpp"
#include "assertume.hpp"
#include "cast.hpp"
#include "randomxparams.hpp"
#include "sse.hpp"
namespace modernRX::aes {
template void hash1R<true>(std::span<std::byte, 64> output, const_span<std::byte> input) noexcept;
template void hash1R<false>(std::span<std::byte, 64> output, const_span<std::byte> input) noexcept;
void hashAndFill1R(std::span<std::byte, 64> hash, std::span<std::byte, 64> seed, std::span<std::byte> scratchpad) noexcept {
// state0, state1, state2, state3 = Blake2b-512("RandomX AesHash1R state")
// state0 = 0d 2c b5 92 de 56 a8 9f 47 db 82 cc ad 3a 98 d7
// state1 = 6e 99 8d 33 98 b7 c7 15 5a 12 9e f5 57 80 e7 ac
// state2 = 17 00 77 6a d0 c7 62 ae 6b 50 79 50 e4 7c a0 e8
// state3 = 0c 24 0a 63 8d 82 ad 07 05 00 a1 79 48 49 99 7e
auto alignas(16) hash_state0{ intrinsics::fromChars(0x0d, 0x2c, 0xb5, 0x92, 0xde, 0x56, 0xa8, 0x9f, 0x47, 0xdb, 0x82, 0xcc, 0xad, 0x3a, 0x98, 0xd7) };
auto alignas(16) hash_state1{ intrinsics::fromChars(0x6e, 0x99, 0x8d, 0x33, 0x98, 0xb7, 0xc7, 0x15, 0x5a, 0x12, 0x9e, 0xf5, 0x57, 0x80, 0xe7, 0xac) };
auto alignas(16) hash_state2{ intrinsics::fromChars(0x17, 0x00, 0x77, 0x6a, 0xd0, 0xc7, 0x62, 0xae, 0x6b, 0x50, 0x79, 0x50, 0xe4, 0x7c, 0xa0, 0xe8) };
auto alignas(16) hash_state3{ intrinsics::fromChars(0x0c, 0x24, 0x0a, 0x63, 0x8d, 0x82, 0xad, 0x07, 0x05, 0x00, 0xa1, 0x79, 0x48, 0x49, 0x99, 0x7e) };
intrinsics::xmm128i_t alignas(16) seed0{ *reinterpret_cast<intrinsics::xmm128i_t*>(seed.data()) };
intrinsics::xmm128i_t alignas(16) seed1{ *reinterpret_cast<intrinsics::xmm128i_t*>(seed.data() + 16) };
intrinsics::xmm128i_t alignas(16) seed2{ *reinterpret_cast<intrinsics::xmm128i_t*>(seed.data() + 32) };
intrinsics::xmm128i_t alignas(16) seed3{ *reinterpret_cast<intrinsics::xmm128i_t*>(seed.data() + 48) };
// key0, key1, key2, key3 = Blake2b-512("RandomX AesGenerator1R keys")
// key0 = 53 a5 ac 6d 09 66 71 62 2b 55 b5 db 17 49 f4 b4
// key1 = 07 af 7c 6d 0d 71 6a 84 78 d3 25 17 4e dc a1 0d
// key2 = f1 62 12 3f c6 7e 94 9f 4f 79 c0 f4 45 e3 20 3e
// key3 = 35 81 ef 6a 7c 31 ba b1 88 4c 31 16 54 91 16 49
constexpr auto alignas(16) key0{ intrinsics::fromChars(0x53, 0xa5, 0xac, 0x6d, 0x09, 0x66, 0x71, 0x62, 0x2b, 0x55, 0xb5, 0xdb, 0x17, 0x49, 0xf4, 0xb4) };
constexpr auto alignas(16) key1{ intrinsics::fromChars(0x07, 0xaf, 0x7c, 0x6d, 0x0d, 0x71, 0x6a, 0x84, 0x78, 0xd3, 0x25, 0x17, 0x4e, 0xdc, 0xa1, 0x0d) };
constexpr auto alignas(16) key2{ intrinsics::fromChars(0xf1, 0x62, 0x12, 0x3f, 0xc6, 0x7e, 0x94, 0x9f, 0x4f, 0x79, 0xc0, 0xf4, 0x45, 0xe3, 0x20, 0x3e) };
constexpr auto alignas(16) key3{ intrinsics::fromChars(0x35, 0x81, 0xef, 0x6a, 0x7c, 0x31, 0xba, 0xb1, 0x88, 0x4c, 0x31, 0x16, 0x54, 0x91, 0x16, 0x49) };
const auto sp_ptr{ reinterpret_cast<uintptr_t>(scratchpad.data()) };
//process 64 bytes at a time in 4 lanes
for (uint64_t i = 0; i < Rx_Scratchpad_L3_Size; i += 64) {
intrinsics::xmm128i_t& scratchpad0{ *reinterpret_cast<intrinsics::xmm128i_t*>(scratchpad.data() + i) };
intrinsics::xmm128i_t& scratchpad1{ *reinterpret_cast<intrinsics::xmm128i_t*>(scratchpad.data() + i + 16) };
intrinsics::xmm128i_t& scratchpad2{ *reinterpret_cast<intrinsics::xmm128i_t*>(scratchpad.data() + i + 32) };
intrinsics::xmm128i_t& scratchpad3{ *reinterpret_cast<intrinsics::xmm128i_t*>(scratchpad.data() + i + 48) };
intrinsics::aes::encode(hash_state0, scratchpad0);
intrinsics::aes::decode(hash_state1, scratchpad1);
intrinsics::aes::encode(hash_state2, scratchpad2);
intrinsics::aes::decode(hash_state3, scratchpad3);
intrinsics::aes::decode(seed0, key0);
intrinsics::aes::encode(seed1, key1);
intrinsics::aes::decode(seed2, key2);
intrinsics::aes::encode(seed3, key3);
scratchpad0 = seed0;
scratchpad1 = seed1;
scratchpad2 = seed2;
scratchpad3 = seed3;
constexpr auto Rf_Size{ 256 }; // TODO: this shouldnt be here.
intrinsics::prefetch<intrinsics::PrefetchMode::T0, 1>(reinterpret_cast<const void*>(sp_ptr - Rf_Size + ((i + Rf_Size + Rx_Scratchpad_L1_Size) % (Rx_Scratchpad_L3_Size + Rf_Size))));
}
// xkey0, xkey1 = Blake2b-256("RandomX AesHash1R xkeys")
// xkey0 = 89 83 fa f6 9f 94 24 8b bf 56 dc 90 01 02 89 06
// xkey1 = d1 63 b2 61 3c e0 f4 51 c6 43 10 ee 9b f9 18 ed
constexpr auto alignas(16) xkey0{ intrinsics::fromChars(0x89, 0x83, 0xfa, 0xf6, 0x9f, 0x94, 0x24, 0x8b, 0xbf, 0x56, 0xdc, 0x90, 0x01, 0x02, 0x89, 0x06) };
constexpr auto alignas(16) xkey1{ intrinsics::fromChars(0xd1, 0x63, 0xb2, 0x61, 0x3c, 0xe0, 0xf4, 0x51, 0xc6, 0x43, 0x10, 0xee, 0x9b, 0xf9, 0x18, 0xed) };
intrinsics::aes::encode(hash_state0, xkey0);
intrinsics::aes::decode(hash_state1, xkey0);
intrinsics::aes::encode(hash_state2, xkey0);
intrinsics::aes::decode(hash_state3, xkey0);
intrinsics::aes::encode(hash_state0, xkey1);
intrinsics::aes::decode(hash_state1, xkey1);
intrinsics::aes::encode(hash_state2, xkey1);
intrinsics::aes::decode(hash_state3, xkey1);
intrinsics::xmm128i_t& output0{ *reinterpret_cast<intrinsics::xmm128i_t*>(hash.data()) };
intrinsics::xmm128i_t& output1{ *reinterpret_cast<intrinsics::xmm128i_t*>(hash.data() + 16) };
intrinsics::xmm128i_t& output2{ *reinterpret_cast<intrinsics::xmm128i_t*>(hash.data() + 32) };
intrinsics::xmm128i_t& output3{ *reinterpret_cast<intrinsics::xmm128i_t*>(hash.data() + 48) };
output0 = hash_state0;
output1 = hash_state1;
output2 = hash_state2;
output3 = hash_state3;
intrinsics::xmm128i_t& output_seed0{ *reinterpret_cast<intrinsics::xmm128i_t*>(seed.data()) };
intrinsics::xmm128i_t& output_seed1{ *reinterpret_cast<intrinsics::xmm128i_t*>(seed.data() + 16) };
intrinsics::xmm128i_t& output_seed2{ *reinterpret_cast<intrinsics::xmm128i_t*>(seed.data() + 32) };
intrinsics::xmm128i_t& output_seed3{ *reinterpret_cast<intrinsics::xmm128i_t*>(seed.data() + 48) };
output_seed0 = seed0;
output_seed1 = seed1;
output_seed2 = seed2;
output_seed3 = seed3;
}
template<bool Fixed>
void hash1R(std::span<std::byte, 64> output, const_span<std::byte> input) noexcept {
ASSERTUME(input.size() > 0 && input.size() % 64 == 0);
// state0, state1, state2, state3 = Blake2b-512("RandomX AesHash1R state")
// state0 = 0d 2c b5 92 de 56 a8 9f 47 db 82 cc ad 3a 98 d7
// state1 = 6e 99 8d 33 98 b7 c7 15 5a 12 9e f5 57 80 e7 ac
// state2 = 17 00 77 6a d0 c7 62 ae 6b 50 79 50 e4 7c a0 e8
// state3 = 0c 24 0a 63 8d 82 ad 07 05 00 a1 79 48 49 99 7e
auto state0{ intrinsics::fromChars(0x0d, 0x2c, 0xb5, 0x92, 0xde, 0x56, 0xa8, 0x9f, 0x47, 0xdb, 0x82, 0xcc, 0xad, 0x3a, 0x98, 0xd7) };
auto state1{ intrinsics::fromChars(0x6e, 0x99, 0x8d, 0x33, 0x98, 0xb7, 0xc7, 0x15, 0x5a, 0x12, 0x9e, 0xf5, 0x57, 0x80, 0xe7, 0xac) };
auto state2{ intrinsics::fromChars(0x17, 0x00, 0x77, 0x6a, 0xd0, 0xc7, 0x62, 0xae, 0x6b, 0x50, 0x79, 0x50, 0xe4, 0x7c, 0xa0, 0xe8) };
auto state3{ intrinsics::fromChars(0x0c, 0x24, 0x0a, 0x63, 0x8d, 0x82, 0xad, 0x07, 0x05, 0x00, 0xa1, 0x79, 0x48, 0x49, 0x99, 0x7e) };
// Switch between fixed and variable output size.
for (uint64_t i = 0; i < (Fixed ? Rx_Scratchpad_L3_Size : input.size()); i += 64) {
const intrinsics::xmm128i_t& input0{ *reinterpret_cast<const intrinsics::xmm128i_t*>(input.data() + i) };
const intrinsics::xmm128i_t& input1{ *reinterpret_cast<const intrinsics::xmm128i_t*>(input.data() + i + 16) };
const intrinsics::xmm128i_t& input2{ *reinterpret_cast<const intrinsics::xmm128i_t*>(input.data() + i + 32) };
const intrinsics::xmm128i_t& input3{ *reinterpret_cast<const intrinsics::xmm128i_t*>(input.data() + i + 48) };
intrinsics::aes::encode(state0, input0);
intrinsics::aes::decode(state1, input1);
intrinsics::aes::encode(state2, input2);
intrinsics::aes::decode(state3, input3);
}
// xkey0, xkey1 = Blake2b-256("RandomX AesHash1R xkeys")
// xkey0 = 89 83 fa f6 9f 94 24 8b bf 56 dc 90 01 02 89 06
// xkey1 = d1 63 b2 61 3c e0 f4 51 c6 43 10 ee 9b f9 18 ed
constexpr auto key0{ intrinsics::fromChars(0x89, 0x83, 0xfa, 0xf6, 0x9f, 0x94, 0x24, 0x8b, 0xbf, 0x56, 0xdc, 0x90, 0x01, 0x02, 0x89, 0x06) };
constexpr auto key1{ intrinsics::fromChars(0xd1, 0x63, 0xb2, 0x61, 0x3c, 0xe0, 0xf4, 0x51, 0xc6, 0x43, 0x10, 0xee, 0x9b, 0xf9, 0x18, 0xed) };
intrinsics::aes::encode(state0, key0);
intrinsics::aes::decode(state1, key0);
intrinsics::aes::encode(state2, key0);
intrinsics::aes::decode(state3, key0);
intrinsics::aes::encode(state0, key1);
intrinsics::aes::decode(state1, key1);
intrinsics::aes::encode(state2, key1);
intrinsics::aes::decode(state3, key1);
intrinsics::xmm128i_t& output0{ *reinterpret_cast<intrinsics::xmm128i_t*>(output.data()) };
intrinsics::xmm128i_t& output1{ *reinterpret_cast<intrinsics::xmm128i_t*>(output.data() + 16) };
intrinsics::xmm128i_t& output2{ *reinterpret_cast<intrinsics::xmm128i_t*>(output.data() + 32) };
intrinsics::xmm128i_t& output3{ *reinterpret_cast<intrinsics::xmm128i_t*>(output.data() + 48) };
output0 = state0;
output1 = state1;
output2 = state2;
output3 = state3;
}
}