﻿/*--------------------------------------------------------------------------------*
  Copyright (C)Nintendo All rights reserved.

  These coded instructions, statements, and computer programs contain proprietary
  information of Nintendo and/or its licensed developers and are protected by
  national and international copyright laws. They may not be disclosed to third
  parties or copied or duplicated in any form, in whole or in part, without the
  prior written consent of Nintendo.

  The content herein is highly confidential and should be handled accordingly.
 *--------------------------------------------------------------------------------*/


#include <nn/crypto/detail/crypto_CbcModeImpl.h>
#include <nn/crypto/crypto_AesDecryptor.h>
#include <nn/crypto/crypto_AesEncryptor.h>
#include <arm_neon.h>


// Whether to use the AArch64 inline-assembly optimizations (~1.4x faster than the intrinsics version)
#ifndef NN_CRYPTO_DETAIL_CBC_USE_ASM
    #define NN_CRYPTO_DETAIL_CBC_USE_ASM 1 // NOLINT(preprocessor/const)
#endif

namespace nn { namespace crypto { namespace detail {


void CbcModeAes128Helper::EncryptBlocks(void* pDst, void* pIv, const void* pSrc, int numBlocks, const AesEncryptor128* pEncryptor) NN_NOEXCEPT
{
    // AES-128 CBC encryption of numBlocks consecutive 16-byte blocks:
    //   C[i] = E(P[i] XOR C[i-1]),  C[-1] = IV read from pIv.
    // Each ciphertext feeds the next block's input, so blocks are processed serially.
    // On return pIv holds the last ciphertext block (or its original contents when
    // numBlocks <= 0), so that a subsequent call can continue the chain.
    const uint8_t* in    = static_cast<const uint8_t*>(pSrc);
    uint8_t*       out   = static_cast<uint8_t*>(pDst);
    uint8_t*       ivBuf = static_cast<uint8_t*>(pIv);

    // The encryptor exposes 11 consecutive 16-byte round keys; load them once up front.
    const uint8_t* roundKeys = pEncryptor->GetRoundKey();
    const uint8x16_t rk0  = vld1q_u8(roundKeys + 16 * 0);
    const uint8x16_t rk1  = vld1q_u8(roundKeys + 16 * 1);
    const uint8x16_t rk2  = vld1q_u8(roundKeys + 16 * 2);
    const uint8x16_t rk3  = vld1q_u8(roundKeys + 16 * 3);
    const uint8x16_t rk4  = vld1q_u8(roundKeys + 16 * 4);
    const uint8x16_t rk5  = vld1q_u8(roundKeys + 16 * 5);
    const uint8x16_t rk6  = vld1q_u8(roundKeys + 16 * 6);
    const uint8x16_t rk7  = vld1q_u8(roundKeys + 16 * 7);
    const uint8x16_t rk8  = vld1q_u8(roundKeys + 16 * 8);
    const uint8x16_t rk9  = vld1q_u8(roundKeys + 16 * 9);
    const uint8x16_t rk10 = vld1q_u8(roundKeys + 16 * 10);

    // Chaining value: starts as the IV, then tracks the most recent ciphertext block.
    uint8x16_t chain = vld1q_u8(ivBuf);

    for (int i = 0; i < numBlocks; ++i)
    {
        // CBC input whitening with the previous ciphertext (or IV).
        uint8x16_t state = veorq_u8(vld1q_u8(in), chain);

        // Nine AESE + AESMC round pairs.
        state = vaesmcq_u8(vaeseq_u8(state, rk0));
        state = vaesmcq_u8(vaeseq_u8(state, rk1));
        state = vaesmcq_u8(vaeseq_u8(state, rk2));
        state = vaesmcq_u8(vaeseq_u8(state, rk3));
        state = vaesmcq_u8(vaeseq_u8(state, rk4));
        state = vaesmcq_u8(vaeseq_u8(state, rk5));
        state = vaesmcq_u8(vaeseq_u8(state, rk6));
        state = vaesmcq_u8(vaeseq_u8(state, rk7));
        state = vaesmcq_u8(vaeseq_u8(state, rk8));

        // Final round: AESE without AESMC, then XOR with the last round key.
        state = veorq_u8(vaeseq_u8(state, rk9), rk10);

        vst1q_u8(out, state);
        chain = state;

        in  += AesEncryptor128::BlockSize;
        out += AesEncryptor128::BlockSize;
    }

    // Persist the chaining value for the caller.
    vst1q_u8(ivBuf, chain);
}


void CbcModeAes128Helper::DecryptBlocks(void* pDst, void* pIv, const void* pSrc, int numBlocks, const AesDecryptor128* pDecryptor) NN_NOEXCEPT
{
    // AES-128 CBC decryption of numBlocks consecutive 16-byte blocks:
    //   P[i] = D(C[i]) XOR C[i-1],  C[-1] = IV read from pIv.
    // The block decryptions are independent of each other (only the final XOR needs the
    // previous ciphertext), so blocks are decrypted three at a time to overlap AES
    // instruction latency; the 0..2 leftover blocks are handled one at a time.
    //
    // Each ciphertext block is loaded before the corresponding plaintext is stored and the
    // chaining value is kept in a register, so exact in-place use (pDst == pSrc) works;
    // partially overlapping buffers are not handled.
    // On return pIv holds the last ciphertext block (or its original contents when
    // numBlocks <= 0), so that a subsequent call can continue the chain.
    const uint8_t* pSrc8 = static_cast<const uint8_t*>(pSrc);
    uint8_t* pDst8       = static_cast<uint8_t*>(pDst);
    uint8_t* pIv8        = static_cast<uint8_t*>(pIv);

    // Preload the 11 consecutive 16-byte round keys. They are applied in reverse order
    // (key10 first) with AESD/AESIMC; this presumably relies on AesDecryptor128 storing them
    // in the form the ARMv8 AESD/AESIMC sequence expects -- TODO confirm against its key schedule.
    const uint8_t* keys = pDecryptor->GetRoundKey();
    const uint8x16_t key0  = vld1q_u8(keys);
    const uint8x16_t key1  = vld1q_u8(keys + 16);
    const uint8x16_t key2  = vld1q_u8(keys + 16 * 2);
    const uint8x16_t key3  = vld1q_u8(keys + 16 * 3);
    const uint8x16_t key4  = vld1q_u8(keys + 16 * 4);
    const uint8x16_t key5  = vld1q_u8(keys + 16 * 5);
    const uint8x16_t key6  = vld1q_u8(keys + 16 * 6);
    const uint8x16_t key7  = vld1q_u8(keys + 16 * 7);
    const uint8x16_t key8  = vld1q_u8(keys + 16 * 8);
    const uint8x16_t key9  = vld1q_u8(keys + 16 * 9);
    const uint8x16_t key10 = vld1q_u8(keys + 16 * 10);

    // Chaining value: starts as the IV, then tracks the most recent ciphertext block.
    uint8x16_t mask = vld1q_u8(pIv8);

    // Process blocks 3 at a time. Compare before subtracting so the signed counter can never
    // overflow (the previous "subtract, then test" form was UB for numBlocks near INT_MIN).
    const int batchSize = 3;
    for (; numBlocks >= batchSize; numBlocks -= batchSize)
    {
        // block1..3 keep the ciphertext for chaining; data1..3 are decrypted in place.
        uint8x16_t block1 = vld1q_u8(pSrc8); pSrc8 += AesDecryptor128::BlockSize;
        uint8x16_t block2 = vld1q_u8(pSrc8); pSrc8 += AesDecryptor128::BlockSize;
        uint8x16_t block3 = vld1q_u8(pSrc8); pSrc8 += AesDecryptor128::BlockSize;
        uint8x16_t data1  = block1;
        uint8x16_t data2  = block2;
        uint8x16_t data3  = block3;

#if NN_CRYPTO_DETAIL_CBC_USE_ASM && defined(NN_BUILD_CONFIG_CPU_ARM64) && (defined(NN_BUILD_CONFIG_COMPILER_CLANG) || defined(NN_BUILD_CONFIG_COMPILER_GCC))
        #define AES128D_RND(num,block) "aesd   %[" #block "].16b, %[key" #num "].16b\n" \
                                       "aesimc %[" #block "].16b, %[" #block "].16b\n"
        #define AES128D_RND1(block)    "aesd   %[" #block "].16b, %[key1].16b\n"
        #define AES128D_RND0(block)    "eor    %[" #block "].16b, %[" #block "].16b, %[key0].16b\n"

        // Latencies of AES instructions are masked by processing 3 blocks at a time.
        // Note that this code is optimized for Cortex-A57 and might run slower on other processors.
        __asm__ volatile
        (
            AES128D_RND(10,data1)  AES128D_RND(10,data2)  AES128D_RND(10,data3)
            AES128D_RND(9,data1)   AES128D_RND(9,data2)   AES128D_RND(9,data3)
            AES128D_RND(8,data1)   AES128D_RND(8,data2)   AES128D_RND(8,data3)
            AES128D_RND(7,data1)   AES128D_RND(7,data2)   AES128D_RND(7,data3)
            AES128D_RND(6,data1)   AES128D_RND(6,data2)   AES128D_RND(6,data3)
            AES128D_RND(5,data1)   AES128D_RND(5,data2)   AES128D_RND(5,data3)
            AES128D_RND(4,data1)   AES128D_RND(4,data2)   AES128D_RND(4,data3)
            AES128D_RND(3,data1)   AES128D_RND(3,data2)   AES128D_RND(3,data3)
            AES128D_RND(2,data1)   AES128D_RND(2,data2)   AES128D_RND(2,data3)
            AES128D_RND1( data1)   AES128D_RND1( data2)   AES128D_RND1( data3)
            AES128D_RND0( data1)   AES128D_RND0( data2)   AES128D_RND0( data3)
            : [data1]"+w"(data1), [data2]"+w"(data2), [data3]"+w"(data3)
            : [key0]"w"(key0), [key1]"w"(key1), [key2]"w"(key2), [key3]"w"(key3),
              [key4]"w"(key4), [key5]"w"(key5), [key6]"w"(key6), [key7]"w"(key7),
              [key8]"w"(key8), [key9]"w"(key9), [key10]"w"(key10)
            : // Empty clobber list
        );

        // These helpers only make sense inside the asm statement above; undefine them so they
        // do not leak into the rest of the translation unit.
        #undef AES128D_RND0
        #undef AES128D_RND1
        #undef AES128D_RND
#else
        data1=vaesimcq_u8(vaesdq_u8(data1,key10));  data2=vaesimcq_u8(vaesdq_u8(data2,key10));  data3=vaesimcq_u8(vaesdq_u8(data3,key10));
        data1=vaesimcq_u8(vaesdq_u8(data1,key9));   data2=vaesimcq_u8(vaesdq_u8(data2,key9));   data3=vaesimcq_u8(vaesdq_u8(data3,key9));
        data1=vaesimcq_u8(vaesdq_u8(data1,key8));   data2=vaesimcq_u8(vaesdq_u8(data2,key8));   data3=vaesimcq_u8(vaesdq_u8(data3,key8));
        data1=vaesimcq_u8(vaesdq_u8(data1,key7));   data2=vaesimcq_u8(vaesdq_u8(data2,key7));   data3=vaesimcq_u8(vaesdq_u8(data3,key7));
        data1=vaesimcq_u8(vaesdq_u8(data1,key6));   data2=vaesimcq_u8(vaesdq_u8(data2,key6));   data3=vaesimcq_u8(vaesdq_u8(data3,key6));
        data1=vaesimcq_u8(vaesdq_u8(data1,key5));   data2=vaesimcq_u8(vaesdq_u8(data2,key5));   data3=vaesimcq_u8(vaesdq_u8(data3,key5));
        data1=vaesimcq_u8(vaesdq_u8(data1,key4));   data2=vaesimcq_u8(vaesdq_u8(data2,key4));   data3=vaesimcq_u8(vaesdq_u8(data3,key4));
        data1=vaesimcq_u8(vaesdq_u8(data1,key3));   data2=vaesimcq_u8(vaesdq_u8(data2,key3));   data3=vaesimcq_u8(vaesdq_u8(data3,key3));
        data1=vaesimcq_u8(vaesdq_u8(data1,key2));   data2=vaesimcq_u8(vaesdq_u8(data2,key2));   data3=vaesimcq_u8(vaesdq_u8(data3,key2));
        data1=vaesdq_u8(data1,key1);                data2=vaesdq_u8(data2,key1);                data3=vaesdq_u8(data3,key1);
        data1=veorq_u8(data1,key0);                 data2=veorq_u8(data2,key0);                 data3=veorq_u8(data3,key0);
#endif

        // CBC un-chaining: each plaintext is XORed with the preceding ciphertext block.
        data1 = veorq_u8(data1, mask);
        data2 = veorq_u8(data2, block1);
        data3 = veorq_u8(data3, block2);
        mask  = block3;

        vst1q_u8(pDst8, data1); pDst8 += AesDecryptor128::BlockSize;
        vst1q_u8(pDst8, data2); pDst8 += AesDecryptor128::BlockSize;
        vst1q_u8(pDst8, data3); pDst8 += AesDecryptor128::BlockSize;
    }

    // Process the remaining blocks (fewer than batchSize) one by one.
    for (; numBlocks > 0; --numBlocks)
    {
        uint8x16_t block = vld1q_u8(pSrc8);
        pSrc8 += AesDecryptor128::BlockSize;

        uint8x16_t data;
        data = vaesimcq_u8(vaesdq_u8(block, key10));
        data = vaesimcq_u8(vaesdq_u8(data,  key9));
        data = vaesimcq_u8(vaesdq_u8(data,  key8));
        data = vaesimcq_u8(vaesdq_u8(data,  key7));
        data = vaesimcq_u8(vaesdq_u8(data,  key6));
        data = vaesimcq_u8(vaesdq_u8(data,  key5));
        data = vaesimcq_u8(vaesdq_u8(data,  key4));
        data = vaesimcq_u8(vaesdq_u8(data,  key3));
        data = vaesimcq_u8(vaesdq_u8(data,  key2));
        data = vaesdq_u8(data, key1);
        data = veorq_u8(data, key0);

        data = veorq_u8(data, mask);
        mask = block;

        vst1q_u8(pDst8, data);
        pDst8 += AesDecryptor128::BlockSize;
    }

    // Persist the chaining value (last ciphertext block) for the caller.
    vst1q_u8(pIv8, mask);
}


}}} // nn::crypto::detail
