﻿/*--------------------------------------------------------------------------------*
  Copyright (C)Nintendo All rights reserved.

  These coded instructions, statements, and computer programs contain proprietary
  information of Nintendo and/or its licensed developers and are protected by
  national and international copyright laws. They may not be disclosed to third
  parties or copied or duplicated in any form, in whole or in part, without the
  prior written consent of Nintendo.

  The content herein is highly confidential and should be handled accordingly.
 *--------------------------------------------------------------------------------*/

#include <nn/crypto/detail/crypto_CtrModeImpl.h>
#include <arm_neon.h>


// Whether to use the hand-scheduled AArch64 inline-assembly implementation of the
// CTR inner loop (~1.4x faster than the intrinsics version). Set to 0 to force the
// intrinsics version. The assembly is only compiled when building for ARM64 with
// Clang; otherwise the intrinsics version is used regardless of this setting.
// Because of the #ifndef guard, a definition made earlier (e.g. on the compiler
// command line) takes precedence over this default.
#ifndef NN_CRYPTO_DETAIL_CTR_USE_ASM
    #define NN_CRYPTO_DETAIL_CTR_USE_ASM 1 // NOLINT(preprocessor/const)
#endif

namespace nn { namespace crypto { namespace detail {


/**
 * @brief Adds one to a 128-bit counter held as 16 big-endian bytes.
 *
 * Byte 0 of @p val is the most significant byte and byte 15 the least
 * significant. The increment wraps modulo 2^128, so an all-0xFF block
 * becomes all zeros.
 *
 * @param[in] val Counter block to increment.
 * @return The incremented counter block, in the same big-endian byte layout.
 */
static inline uint8x16_t Inc128Be(uint8x16_t val)
{
    // Lane 0 covers bytes 0..7 (the high half) and lane 1 covers bytes 8..15
    // (the low half). On the little-endian targets this code assumes, each lane
    // reads its bytes in reverse numeric order, so byte-swap to get the values.
    const uint64x2_t lanes = vreinterpretq_u64_u8(val);
    uint64_t high = __builtin_bswap64(vgetq_lane_u64(lanes, 0));
    uint64_t low  = __builtin_bswap64(vgetq_lane_u64(lanes, 1));

#ifdef NN_BUILD_CONFIG_CPU_ARM64
    // Let the compiler propagate the carry through its native 128-bit integer type.
    using Uint128 = unsigned __int128;
    const Uint128 sum = ((static_cast<Uint128>(high) << 64) | low) + 1;
    high = static_cast<uint64_t>(sum >> 64);
    low  = static_cast<uint64_t>(sum);
#else
    // Portable fallback: bump the low word and carry into the high word on wrap.
    low += 1;
    if (low == 0)
    {
        high += 1;
    }
#endif

    // Swap back to big-endian byte order and rebuild the vector as {high, low}.
    const uint64x2_t result = vcombine_u64(vcreate_u64(__builtin_bswap64(high)),
                                           vcreate_u64(__builtin_bswap64(low)));
    return vreinterpretq_u8_u64(result);
}


// CTR_AES128_UNROLLED3() is the inner step of ProcessBlocksUnrolled. It expands in
// place and works only on local variables of the expanding scope (it performs no
// loads or stores of its own):
//   - reads  key0..key10 (the AES-128 round keys) and ctr1..ctr3 (counter blocks),
//   - writes mask1..mask3 = AES-128 encryption of ctr1..ctr3 (the keystream blocks),
//   - then advances the counters as 128-bit big-endian values:
//     ctr1 = ctr3 + 1, ctr2 = ctr3 + 2, ctr3 = ctr3 + 3,
//     so that ctr1 always holds the next counter value not yet used.
// Two interchangeable implementations follow. The first is hand-scheduled inline
// assembly (ARM64 + Clang, enabled by NN_CRYPTO_DETAIL_CTR_USE_ASM). The second is
// an intrinsics fallback.
#if NN_CRYPTO_DETAIL_CTR_USE_ASM && defined(NN_BUILD_CONFIG_CPU_ARM64) && defined(NN_BUILD_CONFIG_COMPILER_CLANG)
    // AES128_RND(num, data): one full AES round on asm operand %[data] with round key
    // %[key<num>]. AESE (AddRoundKey + SubBytes + ShiftRows) is followed by AESMC (MixColumns).
    #define AES128_RND(num,data) "aese  %[" #data "].16b, %[key" #num "].16b\n" \
                                 "aesmc %[" #data "].16b, %[" #data "].16b\n"
    // AES128_RND9(data): last AESE, with key9. The final AES round has no MixColumns.
    #define AES128_RND9(data)    "aese  %[" #data "].16b, %[key9].16b\n"
    // AES128_RND10(data): final AddRoundKey, a plain XOR with key10.
    #define AES128_RND10(data)   "eor   %[" #data "].16b, %[" #data "].16b, %[key10].16b\n"

    // AES instruction latencies are hidden by interleaving 3 independent blocks.
    // The counter increments run on the scalar ALU, in parallel with the AES instructions
    // (i.e. they are basically free).
    // For legibility, the two instruction streams are written side by side: scalar
    // counter code on the left, AES rounds on the right.
    // Note that this code is optimized for Cortex-A57 and might run slower on other processors.
    // Only /* */ comments may appear inside the macro body. A // comment would swallow
    // the trailing line-continuation backslash.
    #define CTR_AES128_UNROLLED3()                                                \
        mask1 = ctr1;                                                             \
        mask2 = ctr2;                                                             \
        mask3 = ctr3;                                                             \
        uint64_t hi, hi2, lo, lo2;                                                \
        __asm__ volatile                                                          \
        (                                                                         \
                                           AES128_RND(0,mask1)                    \
            "mov  %[hi], %[ctr3].d[0]\n"   AES128_RND(0,mask2)                    \
            "mov  %[lo], %[ctr3].d[1]\n"   AES128_RND(0,mask3)                    \
            "rev  %[hi], %[hi]\n"          AES128_RND(1,mask1)                    \
            "rev  %[lo], %[lo]\n"          AES128_RND(1,mask2)                    \
            "adds %[lo], %[lo], #1\n"      AES128_RND(1,mask3)                    \
            "adc  %[hi], %[hi], xzr\n"     AES128_RND(2,mask1)                    \
            "rev  %[hi2], %[hi]\n"         AES128_RND(2,mask2)                    \
            "rev  %[lo2], %[lo]\n"         AES128_RND(2,mask3)                    \
            "mov  %[ctr1].d[0], %[hi2]\n"  AES128_RND(3,mask1)                    \
            "mov  %[ctr1].d[1], %[lo2]\n"  AES128_RND(3,mask2)                    \
            "adds %[lo], %[lo], #1\n"      AES128_RND(3,mask3)                    \
            "adc  %[hi], %[hi], xzr\n"     AES128_RND(4,mask1)                    \
            "rev  %[hi2], %[hi]\n"         AES128_RND(4,mask2)                    \
            "rev  %[lo2], %[lo]\n"         AES128_RND(4,mask3)                    \
            "mov  %[ctr2].d[0], %[hi2]\n"  AES128_RND(5,mask1)                    \
            "mov  %[ctr2].d[1], %[lo2]\n"  AES128_RND(5,mask2)                    \
            "adds %[lo], %[lo], #1\n"      AES128_RND(5,mask3)                    \
            "adc  %[hi], %[hi], xzr\n"     AES128_RND(6,mask1)                    \
            "rev  %[hi2], %[hi]\n"         AES128_RND(6,mask2)                    \
            "rev  %[lo2], %[lo]\n"         AES128_RND(6,mask3)                    \
            "mov  %[ctr3].d[0], %[hi2]\n"  AES128_RND(7,mask1)                    \
            "mov  %[ctr3].d[1], %[lo2]\n"  AES128_RND(7,mask2)                    \
                                           AES128_RND(7,mask3)                    \
                                           AES128_RND(8,mask1)                    \
                                           AES128_RND(8,mask2)                    \
                                           AES128_RND(8,mask3)                    \
                                           AES128_RND9( mask1)                    \
                                           AES128_RND9( mask2)                    \
                                           AES128_RND9( mask3)                    \
                                           AES128_RND10(mask1)                    \
                                           AES128_RND10(mask2)                    \
                                           AES128_RND10(mask3)                    \
            : [mask1]"+w"(mask1), [mask2]"+w"(mask2), [mask3]"+w"(mask3),         \
              [ctr1]"+w"(ctr1), [ctr2]"+w"(ctr2), [ctr3]"+w"(ctr3),               \
              [hi]"=&r"(hi), [hi2]"=&r"(hi2), [lo]"=&r"(lo), [lo2]"=&r"(lo2)      \
            : [key0]"w"(key0), [key1]"w"(key1), [key2]"w"(key2), [key3]"w"(key3), \
              [key4]"w"(key4), [key5]"w"(key5), [key6]"w"(key6), [key7]"w"(key7), \
              [key8]"w"(key8), [key9]"w"(key9), [key10]"w"(key10)                 \
            : "cc" /* Condition flags are modified by instructions adds & adc */  \
        );

#else

    // Intrinsics fallback with the same contract as the assembly version above.
    // Instruction scheduling is left to the compiler. Each source row applies one AES
    // round to all three blocks. The counters are advanced with Inc128Be after the
    // rounds, in source order.
    #define CTR_AES128_UNROLLED3()                                                                                                   \
        mask1=vaesmcq_u8(vaeseq_u8(ctr1, key0));  mask2=vaesmcq_u8(vaeseq_u8(ctr2, key0));  mask3=vaesmcq_u8(vaeseq_u8(ctr3, key0)); \
        mask1=vaesmcq_u8(vaeseq_u8(mask1,key1));  mask2=vaesmcq_u8(vaeseq_u8(mask2,key1));  mask3=vaesmcq_u8(vaeseq_u8(mask3,key1)); \
        mask1=vaesmcq_u8(vaeseq_u8(mask1,key2));  mask2=vaesmcq_u8(vaeseq_u8(mask2,key2));  mask3=vaesmcq_u8(vaeseq_u8(mask3,key2)); \
        mask1=vaesmcq_u8(vaeseq_u8(mask1,key3));  mask2=vaesmcq_u8(vaeseq_u8(mask2,key3));  mask3=vaesmcq_u8(vaeseq_u8(mask3,key3)); \
        mask1=vaesmcq_u8(vaeseq_u8(mask1,key4));  mask2=vaesmcq_u8(vaeseq_u8(mask2,key4));  mask3=vaesmcq_u8(vaeseq_u8(mask3,key4)); \
        mask1=vaesmcq_u8(vaeseq_u8(mask1,key5));  mask2=vaesmcq_u8(vaeseq_u8(mask2,key5));  mask3=vaesmcq_u8(vaeseq_u8(mask3,key5)); \
        mask1=vaesmcq_u8(vaeseq_u8(mask1,key6));  mask2=vaesmcq_u8(vaeseq_u8(mask2,key6));  mask3=vaesmcq_u8(vaeseq_u8(mask3,key6)); \
        mask1=vaesmcq_u8(vaeseq_u8(mask1,key7));  mask2=vaesmcq_u8(vaeseq_u8(mask2,key7));  mask3=vaesmcq_u8(vaeseq_u8(mask3,key7)); \
        mask1=vaesmcq_u8(vaeseq_u8(mask1,key8));  mask2=vaesmcq_u8(vaeseq_u8(mask2,key8));  mask3=vaesmcq_u8(vaeseq_u8(mask3,key8)); \
        mask1=vaeseq_u8(mask1,key9);              mask2=vaeseq_u8(mask2,key9);              mask3=vaeseq_u8(mask3,key9);             \
        mask1=veorq_u8(mask1,key10);              mask2=veorq_u8(mask2,key10);              mask3=veorq_u8(mask3,key10);             \
        ctr1 = Inc128Be(ctr3);                                                                                                       \
        ctr2 = Inc128Be(ctr1);                                                                                                       \
        ctr3 = Inc128Be(ctr2);
#endif


/**
 * @brief AES-128 CTR bulk path: XORs whole 3-block batches of @p pSrc with the keystream.
 *
 * Only the largest prefix of @p size that is a multiple of 3 * BlockSize is processed.
 * Any remaining bytes are left untouched (presumably the caller handles that tail).
 * Because CTR mode only XORs data with the keystream, the same routine both encrypts
 * and decrypts. m_Counter is updated to the first counter value not consumed here.
 *
 * @param[out] pDst Destination buffer receiving the processed bytes.
 * @param[in]  pSrc Source buffer.
 * @param[in]  size Number of bytes available in @p pSrc / @p pDst.
 * @return Number of bytes processed: a multiple of 3 * BlockSize, not larger than @p size.
 */
template<>
size_t CtrModeImpl<AesEncryptor128>::ProcessBlocksUnrolled(void* pDst, const void* pSrc, size_t size) NN_NOEXCEPT
{
    const uint8_t* const pSrc8 = static_cast<const uint8_t*>(pSrc);
    uint8_t* const       pDst8 = static_cast<uint8_t*>(pDst);

    // Load the 11 AES-128 round keys (16 bytes each) into registers once, up front.
    // CTR_AES128_UNROLLED3() refers to them by the names key0..key10.
    const uint8_t* const pRoundKeys = m_pBlockCipher->GetRoundKey();
    const uint8x16_t key0  = vld1q_u8(pRoundKeys + 16 *  0);
    const uint8x16_t key1  = vld1q_u8(pRoundKeys + 16 *  1);
    const uint8x16_t key2  = vld1q_u8(pRoundKeys + 16 *  2);
    const uint8x16_t key3  = vld1q_u8(pRoundKeys + 16 *  3);
    const uint8x16_t key4  = vld1q_u8(pRoundKeys + 16 *  4);
    const uint8x16_t key5  = vld1q_u8(pRoundKeys + 16 *  5);
    const uint8x16_t key6  = vld1q_u8(pRoundKeys + 16 *  6);
    const uint8x16_t key7  = vld1q_u8(pRoundKeys + 16 *  7);
    const uint8x16_t key8  = vld1q_u8(pRoundKeys + 16 *  8);
    const uint8x16_t key9  = vld1q_u8(pRoundKeys + 16 *  9);
    const uint8x16_t key10 = vld1q_u8(pRoundKeys + 16 * 10);

    // Keep three consecutive counter blocks in flight. Each expansion of the macro
    // encrypts them and then advances all three by 3.
    uint8x16_t ctr1 = vld1q_u8(m_Counter);
    uint8x16_t ctr2 = Inc128Be(ctr1);
    uint8x16_t ctr3 = Inc128Be(ctr2);

    const size_t batchSize = 3 * BlockSize;
    size_t processed = 0;
    for (; size - processed >= batchSize; processed += batchSize)
    {
        // Read the whole batch before writing any of it.
        const uint8_t* const pIn = pSrc8 + processed;
        const uint8x16_t in1 = vld1q_u8(pIn);
        const uint8x16_t in2 = vld1q_u8(pIn + BlockSize);
        const uint8x16_t in3 = vld1q_u8(pIn + 2 * BlockSize);

        uint8x16_t mask1, mask2, mask3;
        CTR_AES128_UNROLLED3()

        uint8_t* const pOut = pDst8 + processed;
        vst1q_u8(pOut,                 veorq_u8(in1, mask1));
        vst1q_u8(pOut + BlockSize,     veorq_u8(in2, mask2));
        vst1q_u8(pOut + 2 * BlockSize, veorq_u8(in3, mask3));
    }

    // After the loop, ctr1 holds the first counter value that has not been used.
    vst1q_u8(m_Counter, ctr1);
    return processed;
}


}}} // nn::crypto::detail
