﻿/*--------------------------------------------------------------------------------*
  Copyright (C)Nintendo All rights reserved.

  These coded instructions, statements, and computer programs contain proprietary
  information of Nintendo and/or its licensed developers and are protected by
  national and international copyright laws. They may not be disclosed to third
  parties or copied or duplicated in any form, in whole or in part, without the
  prior written consent of Nintendo.

  The content herein is highly confidential and should be handled accordingly.
 *--------------------------------------------------------------------------------*/

#include <nn/crypto/detail/crypto_CtrModeImpl.h>
#include <nn/util/util_BitUtil.h>
#include "crypto_AesImpl-cpu.x86x64.h"

namespace nn { namespace crypto { namespace detail {

/**
 * @brief   AES-128 CTR processing of whole blocks using AES-NI, several blocks per iteration.
 *
 * @param[out]  pDst    Destination buffer. Receives pSrc XOR encrypted-counter, block by block.
 * @param[in]   pSrc    Source buffer.
 * @param[in]   size    Number of bytes available in pSrc / pDst.
 *
 * @return  Number of bytes processed: size rounded down to a multiple of BlockSize when
 *          AES-NI is available and size >= BlockSize, otherwise 0 (nothing is touched).
 *
 * @details The AES block encryption is expanded inline so that the counter increment can be
 *          interleaved with it. m_Counter is advanced by one per processed block.
 *          Each block's source data is loaded before the corresponding store is issued.
 */
template<>
size_t CtrModeImpl<AesEncryptor128>::ProcessBlocksUnrolled(void* pDst, const void* pSrc, size_t size) NN_NOEXCEPT
{
    NN_STATIC_ASSERT(BlockSize == sizeof(__m128i));
    NN_STATIC_ASSERT((BlockSize & (BlockSize - 1)) == 0);
    NN_STATIC_ASSERT(sizeof(m_EncryptedCounter) == sizeof(__m128i));
    NN_SDK_REQUIRES_NOT_NULL(pSrc);
    NN_SDK_REQUIRES_NOT_NULL(pDst);

    // Expand AesImpl::EncryptBlock() inline and perform IncrementCounter() at the same time.
    if( (BlockSize <= size) && g_IsAesNiAvailable )
    {
        // Only byte 15 is set. Built with an intrinsic instead of an aggregate initializer so
        // that it does not depend on how the compiler defines __m128i.
        // With _mm_add_epi64 this adds 1 to byte 15 only: byte 15 is the most significant byte
        // of the upper 64-bit lane, so a carry out of it is simply dropped.
        const __m128i One = _mm_setr_epi8(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1);

        const auto BlockCount = 6; // NOTE: 6 is fastest on x86. 8 is fastest on x64, but a common value is used.
        // Smallest value of counter byte 15 at which BlockCount increments would overflow the byte.
        const auto CounterThreshold = static_cast<uint8_t>(0x100 - BlockCount);

        const auto processSize = util::align_down(size, BlockSize);
        const auto pKey = m_pBlockCipher->GetRoundKey();

        // All 11 round keys of AES-128, loaded once up front.
        const __m128i keys[11] =
        {
            _mm_loadu_si128(reinterpret_cast<const __m128i*>(pKey + BlockSize * 0)),
            _mm_loadu_si128(reinterpret_cast<const __m128i*>(pKey + BlockSize * 1)),
            _mm_loadu_si128(reinterpret_cast<const __m128i*>(pKey + BlockSize * 2)),
            _mm_loadu_si128(reinterpret_cast<const __m128i*>(pKey + BlockSize * 3)),
            _mm_loadu_si128(reinterpret_cast<const __m128i*>(pKey + BlockSize * 4)),
            _mm_loadu_si128(reinterpret_cast<const __m128i*>(pKey + BlockSize * 5)),
            _mm_loadu_si128(reinterpret_cast<const __m128i*>(pKey + BlockSize * 6)),
            _mm_loadu_si128(reinterpret_cast<const __m128i*>(pKey + BlockSize * 7)),
            _mm_loadu_si128(reinterpret_cast<const __m128i*>(pKey + BlockSize * 8)),
            _mm_loadu_si128(reinterpret_cast<const __m128i*>(pKey + BlockSize * 9)),
            _mm_loadu_si128(reinterpret_cast<const __m128i*>(pKey + BlockSize * 10))
        };
        NN_STATIC_ASSERT(sizeof(keys) == AesEncryptor128::RoundKeySize);

        // Working copy of the counter kept in a register across the unrolled loop.
        auto counter = _mm_loadu_si128(reinterpret_cast<const __m128i*>(m_Counter));

        // Process BlockCount blocks at a time.
        size_t currentSize = 0;
        for( ; currentSize + BlockSize * BlockCount <= processSize; currentSize += BlockSize * BlockCount )
        {
            // NOTE: An array is not optimized well, so individual variables are used.
            //       Add or remove reg variables to match BlockCount when tuning.
            __m128i reg0;
            __m128i reg1;
            __m128i reg2;
            __m128i reg3;
            __m128i reg4;
            __m128i reg5;
            __m128i key = keys[0];

            // Byte 15 of the counter, read without relying on compiler-specific __m128i members:
            // 16-bit element 7 holds bytes 14 (low half) and 15 (high half).
            const auto counterLastByte = static_cast<uint8_t>(_mm_extract_epi16(counter, 7) >> 8);

            // Byte 15 would overflow within this batch: the SIMD add cannot carry into byte 14,
            // so go through m_Counter and IncrementCounter() for every step instead.
            if( CounterThreshold <= counterLastByte )
            {
                _mm_storeu_si128(reinterpret_cast<__m128i*>(m_Counter), counter);

                reg0 = _mm_xor_si128(_mm_loadu_si128(reinterpret_cast<const __m128i*>(m_Counter)), key);
                IncrementCounter();
                reg1 = _mm_xor_si128(_mm_loadu_si128(reinterpret_cast<const __m128i*>(m_Counter)), key);
                IncrementCounter();
                reg2 = _mm_xor_si128(_mm_loadu_si128(reinterpret_cast<const __m128i*>(m_Counter)), key);
                IncrementCounter();
                reg3 = _mm_xor_si128(_mm_loadu_si128(reinterpret_cast<const __m128i*>(m_Counter)), key);
                IncrementCounter();
                reg4 = _mm_xor_si128(_mm_loadu_si128(reinterpret_cast<const __m128i*>(m_Counter)), key);
                IncrementCounter();
                reg5 = _mm_xor_si128(_mm_loadu_si128(reinterpret_cast<const __m128i*>(m_Counter)), key);
                IncrementCounter();

                // Resynchronize the register copy with the updated member.
                counter = _mm_loadu_si128(reinterpret_cast<const __m128i*>(m_Counter));
            }
            // No overflow of byte 15 within this batch: increment entirely in the register.
            else
            {
                reg0 = _mm_xor_si128(counter, key);
                counter = _mm_add_epi64(counter, One);
                reg1 = _mm_xor_si128(counter, key);
                counter = _mm_add_epi64(counter, One);
                reg2 = _mm_xor_si128(counter, key);
                counter = _mm_add_epi64(counter, One);
                reg3 = _mm_xor_si128(counter, key);
                counter = _mm_add_epi64(counter, One);
                reg4 = _mm_xor_si128(counter, key);
                counter = _mm_add_epi64(counter, One);
                reg5 = _mm_xor_si128(counter, key);
                counter = _mm_add_epi64(counter, One);
            }

            // NOTE: A for loop is not optimized well, so the rounds are unrolled by hand.
            key = keys[1];
            reg0 = _mm_aesenc_si128(reg0, key);
            reg1 = _mm_aesenc_si128(reg1, key);
            reg2 = _mm_aesenc_si128(reg2, key);
            reg3 = _mm_aesenc_si128(reg3, key);
            reg4 = _mm_aesenc_si128(reg4, key);
            reg5 = _mm_aesenc_si128(reg5, key);

            key = keys[2];
            reg0 = _mm_aesenc_si128(reg0, key);
            reg1 = _mm_aesenc_si128(reg1, key);
            reg2 = _mm_aesenc_si128(reg2, key);
            reg3 = _mm_aesenc_si128(reg3, key);
            reg4 = _mm_aesenc_si128(reg4, key);
            reg5 = _mm_aesenc_si128(reg5, key);

            key = keys[3];
            reg0 = _mm_aesenc_si128(reg0, key);
            reg1 = _mm_aesenc_si128(reg1, key);
            reg2 = _mm_aesenc_si128(reg2, key);
            reg3 = _mm_aesenc_si128(reg3, key);
            reg4 = _mm_aesenc_si128(reg4, key);
            reg5 = _mm_aesenc_si128(reg5, key);

            key = keys[4];
            reg0 = _mm_aesenc_si128(reg0, key);
            reg1 = _mm_aesenc_si128(reg1, key);
            reg2 = _mm_aesenc_si128(reg2, key);
            reg3 = _mm_aesenc_si128(reg3, key);
            reg4 = _mm_aesenc_si128(reg4, key);
            reg5 = _mm_aesenc_si128(reg5, key);

            key = keys[5];
            reg0 = _mm_aesenc_si128(reg0, key);
            reg1 = _mm_aesenc_si128(reg1, key);
            reg2 = _mm_aesenc_si128(reg2, key);
            reg3 = _mm_aesenc_si128(reg3, key);
            reg4 = _mm_aesenc_si128(reg4, key);
            reg5 = _mm_aesenc_si128(reg5, key);

            key = keys[6];
            reg0 = _mm_aesenc_si128(reg0, key);
            reg1 = _mm_aesenc_si128(reg1, key);
            reg2 = _mm_aesenc_si128(reg2, key);
            reg3 = _mm_aesenc_si128(reg3, key);
            reg4 = _mm_aesenc_si128(reg4, key);
            reg5 = _mm_aesenc_si128(reg5, key);

            key = keys[7];
            reg0 = _mm_aesenc_si128(reg0, key);
            reg1 = _mm_aesenc_si128(reg1, key);
            reg2 = _mm_aesenc_si128(reg2, key);
            reg3 = _mm_aesenc_si128(reg3, key);
            reg4 = _mm_aesenc_si128(reg4, key);
            reg5 = _mm_aesenc_si128(reg5, key);

            key = keys[8];
            reg0 = _mm_aesenc_si128(reg0, key);
            reg1 = _mm_aesenc_si128(reg1, key);
            reg2 = _mm_aesenc_si128(reg2, key);
            reg3 = _mm_aesenc_si128(reg3, key);
            reg4 = _mm_aesenc_si128(reg4, key);
            reg5 = _mm_aesenc_si128(reg5, key);

            key = keys[9];
            reg0 = _mm_aesenc_si128(reg0, key);
            reg1 = _mm_aesenc_si128(reg1, key);
            reg2 = _mm_aesenc_si128(reg2, key);
            reg3 = _mm_aesenc_si128(reg3, key);
            reg4 = _mm_aesenc_si128(reg4, key);
            reg5 = _mm_aesenc_si128(reg5, key);

            // Final round uses the dedicated "last" instruction.
            key = keys[10];
            reg0 = _mm_aesenclast_si128(reg0, key);
            reg1 = _mm_aesenclast_si128(reg1, key);
            reg2 = _mm_aesenclast_si128(reg2, key);
            reg3 = _mm_aesenclast_si128(reg3, key);
            reg4 = _mm_aesenclast_si128(reg4, key);
            reg5 = _mm_aesenclast_si128(reg5, key);

            const auto pSrc8 = reinterpret_cast<const char*>(pSrc) + currentSize;
            const auto pDst8 = reinterpret_cast<char*>(pDst) + currentSize;

            // XOR each encrypted counter block with the source and write it out.
            // Each source block is loaded before the matching destination store.
            __m128i src;
            src = _mm_loadu_si128(reinterpret_cast<const __m128i*>(pSrc8 + BlockSize * 0));
            _mm_storeu_si128(reinterpret_cast<__m128i*>(pDst8 + BlockSize * 0), _mm_xor_si128(src, reg0));

            src = _mm_loadu_si128(reinterpret_cast<const __m128i*>(pSrc8 + BlockSize * 1));
            _mm_storeu_si128(reinterpret_cast<__m128i*>(pDst8 + BlockSize * 1), _mm_xor_si128(src, reg1));

            src = _mm_loadu_si128(reinterpret_cast<const __m128i*>(pSrc8 + BlockSize * 2));
            _mm_storeu_si128(reinterpret_cast<__m128i*>(pDst8 + BlockSize * 2), _mm_xor_si128(src, reg2));

            src = _mm_loadu_si128(reinterpret_cast<const __m128i*>(pSrc8 + BlockSize * 3));
            _mm_storeu_si128(reinterpret_cast<__m128i*>(pDst8 + BlockSize * 3), _mm_xor_si128(src, reg3));

            src = _mm_loadu_si128(reinterpret_cast<const __m128i*>(pSrc8 + BlockSize * 4));
            _mm_storeu_si128(reinterpret_cast<__m128i*>(pDst8 + BlockSize * 4), _mm_xor_si128(src, reg4));

            src = _mm_loadu_si128(reinterpret_cast<const __m128i*>(pSrc8 + BlockSize * 5));
            _mm_storeu_si128(reinterpret_cast<__m128i*>(pDst8 + BlockSize * 5), _mm_xor_si128(src, reg5));
        }

        // Publish the register copy of the counter back to the member before the tail loop,
        // which works on m_Counter directly.
        _mm_storeu_si128(reinterpret_cast<__m128i*>(m_Counter), counter);

        // Process the remaining whole blocks one at a time.
        for( ; currentSize < processSize; currentSize += BlockSize )
        {
            auto reg = _mm_loadu_si128(reinterpret_cast<const __m128i*>(m_Counter));

            reg = _mm_xor_si128(reg, keys[0]);
            reg = _mm_aesenc_si128(reg, keys[1]);
            reg = _mm_aesenc_si128(reg, keys[2]);
            reg = _mm_aesenc_si128(reg, keys[3]);
            reg = _mm_aesenc_si128(reg, keys[4]);
            reg = _mm_aesenc_si128(reg, keys[5]);
            reg = _mm_aesenc_si128(reg, keys[6]);
            reg = _mm_aesenc_si128(reg, keys[7]);
            reg = _mm_aesenc_si128(reg, keys[8]);
            reg = _mm_aesenc_si128(reg, keys[9]);
            reg = _mm_aesenclast_si128(reg, keys[10]);

            const auto pSrc8 = reinterpret_cast<const char*>(pSrc) + currentSize;
            const auto pDst8 = reinterpret_cast<char*>(pDst) + currentSize;

            const auto src = _mm_loadu_si128(reinterpret_cast<const __m128i*>(pSrc8));
            _mm_storeu_si128(reinterpret_cast<__m128i*>(pDst8), _mm_xor_si128(src, reg));

            IncrementCounter();
        }

        return currentSize;
    }
    // AES-NI unavailable or less than one block supplied: nothing processed here.
    return 0;
} // NOLINT(impl/function_size)

}}}
