﻿/*--------------------------------------------------------------------------------*
  Copyright (C)Nintendo All rights reserved.

  These coded instructions, statements, and computer programs contain proprietary
  information of Nintendo and/or its licensed developers and are protected by
  national and international copyright laws. They may not be disclosed to third
  parties or copied or duplicated in any form, in whole or in part, without the
  prior written consent of Nintendo.

  The content herein is highly confidential and should be handled accordingly.
 *--------------------------------------------------------------------------------*/

#include <cstring>
#include <nn/nn_SdkAssert.h>
#include <nn/nn_Common.h>
#include <nn/nn_Log.h>
#include <nn/crypto/crypto_Config.h>
#include <nn/crypto/detail/crypto_AesImpl.h>
#include <nn/crypto/detail/crypto_Clear.h>

#include <arm_neon.h>

namespace nn { namespace crypto { namespace detail {

namespace
{
    /*
     * AES S-box: the substitution table of the SubBytes step.
     * In this file it is only used by the key expansion (Func0 / Func1);
     * EncryptBlock / DecryptBlock use the ARMv8 AES instructions instead.
     * NOTE(review): lookups are indexed by key-derived bytes, so the memory
     * access pattern of the key expansion depends on the key.
     */
    const Bit8 SubBytesTable[256] =
    {
        0x63, 0x7c, 0x77, 0x7b, 0xf2, 0x6b, 0x6f, 0xc5, 0x30, 0x01, 0x67, 0x2b, 0xfe, 0xd7, 0xab, 0x76,
        0xca, 0x82, 0xc9, 0x7d, 0xfa, 0x59, 0x47, 0xf0, 0xad, 0xd4, 0xa2, 0xaf, 0x9c, 0xa4, 0x72, 0xc0,
        0xb7, 0xfd, 0x93, 0x26, 0x36, 0x3f, 0xf7, 0xcc, 0x34, 0xa5, 0xe5, 0xf1, 0x71, 0xd8, 0x31, 0x15,
        0x04, 0xc7, 0x23, 0xc3, 0x18, 0x96, 0x05, 0x9a, 0x07, 0x12, 0x80, 0xe2, 0xeb, 0x27, 0xb2, 0x75,
        0x09, 0x83, 0x2c, 0x1a, 0x1b, 0x6e, 0x5a, 0xa0, 0x52, 0x3b, 0xd6, 0xb3, 0x29, 0xe3, 0x2f, 0x84,
        0x53, 0xd1, 0x00, 0xed, 0x20, 0xfc, 0xb1, 0x5b, 0x6a, 0xcb, 0xbe, 0x39, 0x4a, 0x4c, 0x58, 0xcf,
        0xd0, 0xef, 0xaa, 0xfb, 0x43, 0x4d, 0x33, 0x85, 0x45, 0xf9, 0x02, 0x7f, 0x50, 0x3c, 0x9f, 0xa8,
        0x51, 0xa3, 0x40, 0x8f, 0x92, 0x9d, 0x38, 0xf5, 0xbc, 0xb6, 0xda, 0x21, 0x10, 0xff, 0xf3, 0xd2,
        0xcd, 0x0c, 0x13, 0xec, 0x5f, 0x97, 0x44, 0x17, 0xc4, 0xa7, 0x7e, 0x3d, 0x64, 0x5d, 0x19, 0x73,
        0x60, 0x81, 0x4f, 0xdc, 0x22, 0x2a, 0x90, 0x88, 0x46, 0xee, 0xb8, 0x14, 0xde, 0x5e, 0x0b, 0xdb,
        0xe0, 0x32, 0x3a, 0x0a, 0x49, 0x06, 0x24, 0x5c, 0xc2, 0xd3, 0xac, 0x62, 0x91, 0x95, 0xe4, 0x79,
        0xe7, 0xc8, 0x37, 0x6d, 0x8d, 0xd5, 0x4e, 0xa9, 0x6c, 0x56, 0xf4, 0xea, 0x65, 0x7a, 0xae, 0x08,
        0xba, 0x78, 0x25, 0x2e, 0x1c, 0xa6, 0xb4, 0xc6, 0xe8, 0xdd, 0x74, 0x1f, 0x4b, 0xbd, 0x8b, 0x8a,
        0x70, 0x3e, 0xb5, 0x66, 0x48, 0x03, 0xf6, 0x0e, 0x61, 0x35, 0x57, 0xb9, 0x86, 0xc1, 0x1d, 0x9e,
        0xe1, 0xf8, 0x98, 0x11, 0x69, 0xd9, 0x8e, 0x94, 0x9b, 0x1e, 0x87, 0xe9, 0xce, 0x55, 0x28, 0xdf,
        0x8c, 0xa1, 0x89, 0x0d, 0xbf, 0xe6, 0x42, 0x68, 0x41, 0x99, 0x2d, 0x0f, 0xb0, 0x54, 0xbb, 0x16,
    };

    /*
     * Round constants (Rcon) for the key expansion: successive powers of x
     * (0x02) in GF(2^8) reduced by the AES polynomial, starting at 0x01
     * (note 0x80 * 2 -> 0x1B). Initialize() XORs entry
     * (i / KeySizeInWord - 1) into byte 0 of every KeySizeInWord-th word.
     * NOTE(review): presumably only a prefix of this table is reached for
     * AES-128/192/256 (AES-128 needs 10 entries if RoundCount == 10) — confirm
     * against RoundCount in the header.
     */
    const Bit8 RoundKeyRcon0[] =
    {
        0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
        0x1B, 0x36, 0x6C, 0xD8, 0xAB, 0x4D, 0x9A, 0x2F,
        0x5E, 0xBC, 0x63, 0xC6, 0x97, 0x35, 0x6A, 0xD4,
        0xB3, 0x7D, 0xFA, 0xEF, 0xC5, 0x91,
    };

#if defined(NN_BUILD_CONFIG_ENDIAN_BIG)
    /*
     * Bit offsets, inside a Bit32 read straight from memory, of the bytes at
     * addresses +0, +1, +2 and +3 of that word.
     * Big endian: the byte at the lowest address is the most significant one.
     */
    const int aesWordByte0 = 24;
    const int aesWordByte1 = 16;
    const int aesWordByte2 = 8;
    const int aesWordByte3 = 0;
#elif defined(NN_BUILD_CONFIG_ENDIAN_LITTLE)
    /*
     * Little endian: the byte at the lowest address is the least significant one.
     */
    const int aesWordByte0 = 0;
    const int aesWordByte1 = 8;
    const int aesWordByte2 = 16;
    const int aesWordByte3 = 24;
#else
#error unknown NN_BUILD_CONFIG_ENDIAN
#endif

    /*
     * RotWord followed by SubWord (key-expansion helper).
     * Byte positions follow memory order via aesWordByte0..3: the substituted
     * input byte k+1 lands in output byte k, and input byte 0 wraps around to
     * output byte 3.
     *
     * Each table entry is widened to Bit32 before it is shifted. A Bit8 would
     * otherwise be promoted to signed int, and shifting a value >= 0x80 left by
     * 24 does not fit in int.
     */
    inline Bit32 Func0(Bit32 x)
    {
        return ( (static_cast<Bit32>(SubBytesTable[(x >> aesWordByte0) & 0xFFUL]) << aesWordByte3)
                 ^ (static_cast<Bit32>(SubBytesTable[(x >> aesWordByte1) & 0xFFUL]) << aesWordByte0)
                 ^ (static_cast<Bit32>(SubBytesTable[(x >> aesWordByte2) & 0xFFUL]) << aesWordByte1)
                 ^ (static_cast<Bit32>(SubBytesTable[(x >> aesWordByte3) & 0xFFUL]) << aesWordByte2) );
    }
    /*
     * SubWord only (no rotation): every byte of x is replaced in place by its
     * S-box value. Used for the extra step of the 8-word (AES-256) schedule.
     *
     * Entries are widened to Bit32 before shifting to avoid shifting a
     * promoted signed int into its sign bit.
     */
    inline Bit32 Func1(Bit32 x)
    {
        return ( (static_cast<Bit32>(SubBytesTable[(x >> aesWordByte0) & 0xFFUL]) << aesWordByte0)
                 ^ (static_cast<Bit32>(SubBytesTable[(x >> aesWordByte1) & 0xFFUL]) << aesWordByte1)
                 ^ (static_cast<Bit32>(SubBytesTable[(x >> aesWordByte2) & 0xFFUL]) << aesWordByte2)
                 ^ (static_cast<Bit32>(SubBytesTable[(x >> aesWordByte3) & 0xFFUL]) << aesWordByte3) );
    }

}   // anonymous namespace


template <size_t KeySize>
AesImpl<KeySize>::~AesImpl() NN_NOEXCEPT
{
    /* Wipe the entire object, round-key schedule included, on destruction. */
    const size_t wipeSize = sizeof(AesImpl);
    ClearMemory(this, wipeSize);
}

template <size_t KeySize>
void AesImpl<KeySize>::Initialize(const void* pKey, size_t keySize, bool isEncryptionKey) NN_NOEXCEPT
{
    /*
     * Expands the cipher key into the round-key schedule m_RoundKey.
     *
     * pKey            : KeySize bytes of cipher key.
     * keySize         : must equal KeySize (checked only by the SDK precondition).
     * isEncryptionKey : when false, the middle round keys are additionally run
     *                   through InvMixColumns so DecryptBlock can use them directly.
     */

    // keySize is a size_t; it must be converted before being formatted with %d.
    NN_SDK_REQUIRES(keySize == KeySize, "invalid key size. keySize(=%d) must be %d",
                    static_cast<int>(keySize), static_cast<int>(KeySize));
    // keySize is otherwise unused: everything below is sized by the compile-time
    // KeySize, so the schedule always matches RoundCount and cannot overrun
    // m_RoundKey even if the (debug-only) precondition is compiled out.
    static_cast<void>(keySize);

    /* Key expansion */
    const int KeySizeInWord = static_cast<int>(KeySize / sizeof(Bit32));
    Bit32*    pDst = m_RoundKey;

    /* The first KeySizeInWord words of the schedule are the cipher key itself. */
    std::memcpy(pDst, pKey, KeySize);

    /* Each new word starts from the last word produced so far. */
    Bit32 reg = pDst[KeySizeInWord - 1];

    for (int i = KeySizeInWord; i < (RoundCount + 1) * 4; ++i)
    {
        if ((i % KeySizeInWord) == 0)
        {
            /* RotWord + SubWord, then XOR the round constant into byte 0. */
            reg = Func0(reg);
            // Widen before shifting (the shift is 24 on big-endian builds).
            reg ^= (static_cast<Bit32>(RoundKeyRcon0[i / KeySizeInWord - 1]) << aesWordByte0);
        }
        else if ((KeySizeInWord > 6) && ((i % KeySizeInWord) == 4))
        {
            /* 8-word (AES-256) keys only: an extra SubWord halfway through each group. */
            reg = Func1(reg);
        }

        reg ^= pDst[i - KeySizeInWord];
        pDst[i] = reg;
    }

    /*
     * For decryption, pre-apply InvMixColumns to every round key except the
     * first (index 0) and the last (index RoundCount). This is the
     * "equivalent inverse cipher" key layout that DecryptBlock's AESD/AESIMC
     * sequence expects.
     */
    if (!isEncryptionKey)
    {
        uint8_t* pKey8 = reinterpret_cast<uint8_t*>(m_RoundKey);
        pKey8 += BlockSize;

        for (int i = 1; i < RoundCount; ++i)
        {
            uint8x16_t key = vld1q_u8(pKey8);
            key = vaesimcq_u8(key);
            vst1q_u8(pKey8, key);
            pKey8 += BlockSize;
        }
    }
}

template <size_t KeySize>
void AesImpl<KeySize>::EncryptBlock(void* pDst, size_t dstSize, const void* pSrc, size_t srcSize) const NN_NOEXCEPT
{
    /*
     * Encrypts exactly one BlockSize-byte block from pSrc into pDst using the
     * schedule produced by Initialize(..., isEncryptionKey = true).
     */

    // Exactly BlockSize bytes are read and written below; reject smaller
    // buffers instead of silently overrunning them.
    NN_SDK_REQUIRES(dstSize >= static_cast<size_t>(BlockSize), "dstSize(=%d) must be >= %d",
                    static_cast<int>(dstSize), static_cast<int>(BlockSize));
    NN_SDK_REQUIRES(srcSize >= static_cast<size_t>(BlockSize), "srcSize(=%d) must be >= %d",
                    static_cast<int>(srcSize), static_cast<int>(BlockSize));

    const uint8_t* pKey8 = reinterpret_cast<const uint8_t*>(m_RoundKey);

    /* Load the input block into a vector register. */
    uint8x16_t tmp = vld1q_u8(static_cast<const uint8_t*>(pSrc));

    for (int round = 1; round < RoundCount; ++round)
    {
        /* AESE = AddRoundKey (previous round key) + SubBytes + ShiftRows. */
        tmp = vaeseq_u8(tmp, vld1q_u8(pKey8));
        pKey8 += BlockSize;

        /* MixColumns */
        tmp = vaesmcq_u8(tmp);
    }

    /* Final round has no MixColumns: AddRoundKey + SubBytes + ShiftRows only. */
    tmp = vaeseq_u8(tmp, vld1q_u8(pKey8));
    pKey8 += BlockSize;

    /* AddRoundKey with the last round key. */
    tmp = veorq_u8(tmp, vld1q_u8(pKey8));

    /* Store the result to the output buffer. */
    vst1q_u8(static_cast<uint8_t*>(pDst), tmp);
}

template <size_t KeySize>
void AesImpl<KeySize>::DecryptBlock(void* pDst, size_t dstSize, const void* pSrc, size_t srcSize) const NN_NOEXCEPT
{
    /*
     * Decrypts exactly one BlockSize-byte block from pSrc into pDst. Expects
     * the schedule produced by Initialize(..., isEncryptionKey = false), whose
     * middle round keys have already been passed through InvMixColumns.
     */

    // Exactly BlockSize bytes are read and written below; reject smaller
    // buffers instead of silently overrunning them.
    NN_SDK_REQUIRES(dstSize >= static_cast<size_t>(BlockSize), "dstSize(=%d) must be >= %d",
                    static_cast<int>(dstSize), static_cast<int>(BlockSize));
    NN_SDK_REQUIRES(srcSize >= static_cast<size_t>(BlockSize), "srcSize(=%d) must be >= %d",
                    static_cast<int>(srcSize), static_cast<int>(BlockSize));

    /* Round keys are consumed from the last one backwards. */
    const uint8_t* pKey8 = reinterpret_cast<const uint8_t*>(m_RoundKey) + (RoundCount * BlockSize);

    /* Load the input block into a vector register. */
    uint8x16_t tmp = vld1q_u8(static_cast<const uint8_t*>(pSrc));

    for (int round = RoundCount; round > 1; --round)
    {
        /* AESD = AddRoundKey + InvSubBytes + InvShiftRows. */
        tmp = vaesdq_u8(tmp, vld1q_u8(pKey8));
        pKey8 -= BlockSize;

        /* InvMixColumns (logically the start of the next round). */
        tmp = vaesimcq_u8(tmp);
    }

    /* Last inverse round: AddRoundKey + InvSubBytes + InvShiftRows, no InvMixColumns. */
    tmp = vaesdq_u8(tmp, vld1q_u8(pKey8));
    pKey8 -= BlockSize;

    /* AddRoundKey with the initial round key. */
    tmp = veorq_u8(tmp, vld1q_u8(pKey8));

    /* Store the result to the output buffer. */
    vst1q_u8(static_cast<uint8_t*>(pDst), tmp);
}



// ASM specializations for AES-128, slightly more efficient than compiler-generated code (~20%)
// This is achieved by forcing the two instructions of an AES round to be consecutive.
// This optimization works on Cortex-A57, probably not on Cortex-A53, hence the conditional.
#if defined(NN_BUILD_CONFIG_CPU_ARM64) && defined(NN_BUILD_CONFIG_COMPILER_CLANG) && defined(NN_BUILD_CONFIG_CPU_CORTEX_A57)

template <>
void AesImpl<16>::EncryptBlock(void* pDst, size_t dstSize, const void* pSrc, size_t srcSize) const NN_NOEXCEPT
{
    /*
     * AES-128 single-block encryption, hand-scheduled so that every AESE is
     * immediately followed by its AESMC. Each round key is loaded into v1 just
     * before the round that uses it.
     */

    // Exactly BlockSize bytes are read and written; reject smaller buffers.
    NN_SDK_REQUIRES(dstSize >= static_cast<size_t>(BlockSize), "dstSize(=%d) must be >= %d",
                    static_cast<int>(dstSize), static_cast<int>(BlockSize));
    NN_SDK_REQUIRES(srcSize >= static_cast<size_t>(BlockSize), "srcSize(=%d) must be >= %d",
                    static_cast<int>(srcSize), static_cast<int>(BlockSize));

    const uint8_t* pKey8 = reinterpret_cast<const uint8_t*>(m_RoundKey);
    uint8x16_t block = vld1q_u8(static_cast<const uint8_t*>(pSrc));

    // The asm reads the round-key schedule through pKey8, but only the pointer
    // value is an operand. Without the "memory" clobber the compiler is free
    // to reorder (or drop as dead) stores to m_RoundKey around this statement.
    __asm__ volatile
    (
        "ldr   q1, [%[pKey8]]\n"
        "aese  %[block].16b, v1.16b\n"
        "aesmc %[block].16b, %[block].16b\n"
        "ldr   q1, [%[pKey8], #16]\n"
        "aese  %[block].16b, v1.16b\n"
        "aesmc %[block].16b, %[block].16b\n"
        "ldr   q1, [%[pKey8], #32]\n"
        "aese  %[block].16b, v1.16b\n"
        "aesmc %[block].16b, %[block].16b\n"
        "ldr   q1, [%[pKey8], #48]\n"
        "aese  %[block].16b, v1.16b\n"
        "aesmc %[block].16b, %[block].16b\n"
        "ldr   q1, [%[pKey8], #64]\n"
        "aese  %[block].16b, v1.16b\n"
        "aesmc %[block].16b, %[block].16b\n"
        "ldr   q1, [%[pKey8], #80]\n"
        "aese  %[block].16b, v1.16b\n"
        "aesmc %[block].16b, %[block].16b\n"
        "ldr   q1, [%[pKey8], #96]\n"
        "aese  %[block].16b, v1.16b\n"
        "aesmc %[block].16b, %[block].16b\n"
        "ldr   q1, [%[pKey8], #112]\n"
        "aese  %[block].16b, v1.16b\n"
        "aesmc %[block].16b, %[block].16b\n"
        "ldr   q1, [%[pKey8], #128]\n"
        "aese  %[block].16b, v1.16b\n"
        "aesmc %[block].16b, %[block].16b\n"
        "ldr   q1, [%[pKey8], #144]\n"
        "aese  %[block].16b, v1.16b\n"
        "ldr   q1, [%[pKey8], #160]\n"
        "eor   %[block].16b, %[block].16b, v1.16b\n"
        : [block]"+w"(block)
        : [pKey8]"r"(pKey8)
        : "v1", "memory"
    );

    vst1q_u8(static_cast<uint8_t*>(pDst), block);
}

template <>
void AesImpl<16>::DecryptBlock(void* pDst, size_t dstSize, const void* pSrc, size_t srcSize) const NN_NOEXCEPT
{
    /*
     * AES-128 single-block decryption, hand-scheduled so that every AESD is
     * immediately followed by its AESIMC. Round keys are consumed from offset
     * 160 (the last) down to 0 (the first); the middle ones are expected to
     * have been passed through InvMixColumns by Initialize(..., false).
     */

    // Exactly BlockSize bytes are read and written; reject smaller buffers.
    NN_SDK_REQUIRES(dstSize >= static_cast<size_t>(BlockSize), "dstSize(=%d) must be >= %d",
                    static_cast<int>(dstSize), static_cast<int>(BlockSize));
    NN_SDK_REQUIRES(srcSize >= static_cast<size_t>(BlockSize), "srcSize(=%d) must be >= %d",
                    static_cast<int>(srcSize), static_cast<int>(BlockSize));

    const uint8_t* pKey8 = reinterpret_cast<const uint8_t*>(m_RoundKey);
    uint8x16_t block = vld1q_u8(static_cast<const uint8_t*>(pSrc));

    // The asm reads the round-key schedule through pKey8, but only the pointer
    // value is an operand. Without the "memory" clobber the compiler is free
    // to reorder (or drop as dead) stores to m_RoundKey around this statement.
    __asm__ volatile
    (
        "ldr   q1, [%[pKey8], #160]\n"
        "aesd   %[block].16b, v1.16b\n"
        "aesimc %[block].16b, %[block].16b\n"
        "ldr   q1, [%[pKey8], #144]\n"
        "aesd   %[block].16b, v1.16b\n"
        "aesimc %[block].16b, %[block].16b\n"
        "ldr   q1, [%[pKey8], #128]\n"
        "aesd   %[block].16b, v1.16b\n"
        "aesimc %[block].16b, %[block].16b\n"
        "ldr   q1, [%[pKey8], #112]\n"
        "aesd   %[block].16b, v1.16b\n"
        "aesimc %[block].16b, %[block].16b\n"
        "ldr   q1, [%[pKey8], #96]\n"
        "aesd   %[block].16b, v1.16b\n"
        "aesimc %[block].16b, %[block].16b\n"
        "ldr   q1, [%[pKey8], #80]\n"
        "aesd   %[block].16b, v1.16b\n"
        "aesimc %[block].16b, %[block].16b\n"
        "ldr   q1, [%[pKey8], #64]\n"
        "aesd   %[block].16b, v1.16b\n"
        "aesimc %[block].16b, %[block].16b\n"
        "ldr   q1, [%[pKey8], #48]\n"
        "aesd   %[block].16b, v1.16b\n"
        "aesimc %[block].16b, %[block].16b\n"
        "ldr   q1, [%[pKey8], #32]\n"
        "aesd   %[block].16b, v1.16b\n"
        "aesimc %[block].16b, %[block].16b\n"
        "ldr   q1, [%[pKey8], #16]\n"
        "aesd   %[block].16b, v1.16b\n"
        "ldr   q1, [%[pKey8]]\n"
        "eor   %[block].16b, %[block].16b, v1.16b\n"
        : [block]"+w"(block)
        : [pKey8]"r"(pKey8)
        : "v1", "memory"
    );

    vst1q_u8(static_cast<uint8_t*>(pDst), block);
}

#endif // defined(NN_BUILD_CONFIG_CPU_ARM64) && defined(NN_BUILD_CONFIG_COMPILER_CLANG) && defined(NN_BUILD_CONFIG_CPU_CORTEX_A57)

/*
 * Explicit template instantiation for the three supported key sizes in bytes
 * (AES-128 / AES-192 / AES-256), so the member definitions in this file are
 * emitted for use by other translation units.
 */
template class AesImpl<16>;
template class AesImpl<24>;
template class AesImpl<32>;

}}}
