﻿/*--------------------------------------------------------------------------------*
  Copyright (C)Nintendo All rights reserved.

  These coded instructions, statements, and computer programs contain proprietary
  information of Nintendo and/or its licensed developers and are protected by
  national and international copyright laws. They may not be disclosed to third
  parties or copied or duplicated in any form, in whole or in part, without the
  prior written consent of Nintendo.

  The content herein is highly confidential and should be handled accordingly.
 *--------------------------------------------------------------------------------*/

#include <cstring>
#include <algorithm>
#include <nn/nn_SdkAssert.h>
#include <nn/nn_Common.h>
#include <nn/crypto/crypto_Config.h>
#include <nn/crypto/crypto_Sha256Generator.h>
#include <nn/crypto/detail/crypto_Sha256Impl.h>
#include <nn/crypto/detail/crypto_Clear.h>

#include "crypto_Util.h"
#include <arm_neon.h>

namespace nn { namespace crypto { namespace detail {

namespace
{
    /* Per-round constants: ProcessBlocks() adds four of these words to four message-schedule
       words before each four-round SHA-256 step (64 words in total, one per round). */
    NN_ALIGNAS(64) const uint32_t RoundConstants[64] = // Aligned on a cache line
    {
        0x428a2f98, 0x71374491, 0xb5c0fbcf, 0xe9b5dba5,
        0x3956c25b, 0x59f111f1, 0x923f82a4, 0xab1c5ed5,
        0xd807aa98, 0x12835b01, 0x243185be, 0x550c7dc3,
        0x72be5d74, 0x80deb1fe, 0x9bdc06a7, 0xc19bf174,
        0xe49b69c1, 0xefbe4786, 0x0fc19dc6, 0x240ca1cc,
        0x2de92c6f, 0x4a7484aa, 0x5cb0a9dc, 0x76f988da,
        0x983e5152, 0xa831c66d, 0xb00327c8, 0xbf597fc7,
        0xc6e00bf3, 0xd5a79147, 0x06ca6351, 0x14292967,
        0x27b70a85, 0x2e1b2138, 0x4d2c6dfc, 0x53380d13,
        0x650a7354, 0x766a0abb, 0x81c2c92e, 0x92722c85,
        0xa2bfe8a1, 0xa81a664b, 0xc24b8b70, 0xc76c51a3,
        0xd192e819, 0xd6990624, 0xf40e3585, 0x106aa070,
        0x19a4c116, 0x1e376c08, 0x2748774c, 0x34b0bcb5,
        0x391c0cb3, 0x4ed8aa4a, 0x5b9cca4f, 0x682e6ff3,
        0x748f82ee, 0x78a5636f, 0x84c87814, 0x8cc70208,
        0x90befffa, 0xa4506ceb, 0xbef9a3f7, 0xc67178f2,
    };

}   // anonymous namespace


// Hands the entire object (hash state, bit counter and buffered input) to ClearMemory()
// when the instance is destroyed, so none of it is left behind in this object's storage.
Sha256Impl::~Sha256Impl() NN_NOEXCEPT
{
    ClearMemory(this, sizeof(Sha256Impl));
}

// Resets the object so that a new message can be hashed: loads the SHA-256 initial
// hash value, discards any buffered input and zeroes the processed-bit counter.
void Sha256Impl::Initialize() NN_NOEXCEPT
{
    // Initial hash value H(0) of SHA-256.
    static const uint32_t InitialHash[8] =
    {
        0x6a09e667, 0xbb67ae85, 0x3c6ef372, 0xa54ff53a,
        0x510e527f, 0x9b05688c, 0x1f83d9ab, 0x5be0cd19,
    };

    for (int i = 0; i < 8; ++i)
    {
        m_IntermediateHash[i] = InitialHash[i];
    }

    m_BufferedByte = 0;
    m_InputBitCount = 0;

    m_State = State_Initialized;
}

// Convenience wrapper: runs the compression function over exactly one block at pData.
inline void Sha256Impl::ProcessBlock(const void* pData) NN_NOEXCEPT
{
    const uint8_t* const pBlock = static_cast<const uint8_t*>(pData);
    ProcessBlocks(pBlock, 1);
}

// Feeds dataSize bytes starting at pData into the running hash.
//
// Complete BlockSize-byte blocks are hashed immediately; a trailing partial block is kept
// in m_TemporalBlockBuffer and completed by a later Update() or by ProcessLastBlock().
// Must only be called between Initialize()/InitializeWithContext() and GetHash().
// A call with dataSize == 0 is a no-op (pData is not touched in that case).
void Sha256Impl::Update(const void* pData, size_t dataSize) NN_NOEXCEPT
{
    NN_SDK_REQUIRES(m_State == State_Initialized, "Invalid state. Please restart from Initialize().");

    // Nothing to hash. Returning here also keeps a possibly null pData from being
    // handed to memcpy() with a zero length, which is undefined behavior.
    if (dataSize == 0)
    {
        return;
    }

    /* Count only the bytes that are actually hashed by this call (whole blocks);
       the buffered remainder is added later in ProcessLastBlock().
       The multiplication is done in 64 bits so that the bit count cannot wrap
       on targets where size_t is only 32 bits wide. */
    m_InputBitCount += static_cast<uint64_t>((m_BufferedByte + dataSize) / BlockSize) * BlockSize * 8;

    const Bit8*  pData8 = static_cast<const Bit8*>(pData);
    size_t remaining = dataSize;

    /* If a previous call left a partial block behind, top it up until it is
       either complete or the input runs out. */
    if (m_BufferedByte > 0)
    {
        size_t fillSize = std::min(BlockSize - m_BufferedByte, remaining);

        std::memcpy(m_TemporalBlockBuffer + m_BufferedByte, pData8, fillSize);
        pData8 += fillSize;
        remaining -= fillSize;
        m_BufferedByte += fillSize;
        if (m_BufferedByte == BlockSize)
        {
            ProcessBlock(m_TemporalBlockBuffer);
            m_BufferedByte = 0;
        }
    }

    // Hash all complete blocks straight from the caller's buffer, without copying.
    if (remaining >= BlockSize)
    {
        size_t fullBlocks = remaining / BlockSize;
        ProcessBlocks(pData8, fullBlocks);
        pData8    += fullBlocks * BlockSize;
        remaining -= fullBlocks * BlockSize;
    }

    /* Keep the tail that is shorter than one block for the next call.
       The buffer is necessarily empty here whenever remaining is non-zero. */
    if (remaining > 0)
    {
        m_BufferedByte = remaining;
        std::memcpy(m_TemporalBlockBuffer, pData8, remaining);
    }
}

// Writes the HashSize-byte digest to pHash.
//
// The first call after hashing runs the final padding block and moves the object to
// State_Done; further calls in State_Done only copy the already finished result again.
// hashSize is the capacity of pHash and is only checked by the assertion.
void Sha256Impl::GetHash(void* pHash, size_t hashSize) NN_NOEXCEPT
{
    NN_SDK_REQUIRES((m_State == State_Initialized) || (m_State == State_Done), "Invalid state. Please restart from Initialize().");
    // "%d" consumes an int, so the constant is converted explicitly instead of being
    // passed through the variadic argument list with whatever width it is declared with.
    NN_SDK_REQUIRES(hashSize >= HashSize, "It requires %d bytes buffer", static_cast<int>(HashSize));
    NN_UNUSED(hashSize);

    if (m_State == State_Initialized)
    {
        ProcessLastBlock();
        m_State = State_Done;
    }

    // m_IntermediateHash holds native 32-bit words; on little-endian targets each word is
    // byte-swapped while copying, on big-endian targets the words are copied unchanged.
#if defined(NN_BUILD_CONFIG_ENDIAN_LITTLE)
    CopyDataWithSwappingEndianBy32Bit(pHash, m_IntermediateHash, HashSize);
#elif defined(NN_BUILD_CONFIG_ENDIAN_BIG)
    std::memcpy(pHash, m_IntermediateHash, HashSize);
#else
#error unknown NN_BUILD_CONFIG_ENDIAN
#endif
}


// Runs the SHA-256 compression function over blockCount consecutive 64-byte blocks
// starting at pData and folds the result into m_IntermediateHash.
//
// hash0/hash1 hold the two halves of the working state and hash_prev0/hash_prev1 the
// state the current block started from. The feed-forward addition of a block is deferred:
// it is performed at the top of the next iteration (hash0/hash1 start as zero, so the
// first addition simply loads the stored state) and once more after the loop.
//
// blockCount == 0 is a no-op.
void Sha256Impl::ProcessBlocks(const uint8_t* pData, size_t blockCount) NN_NOEXCEPT
{
    // The do/while below tests the counter only after a block has been consumed, so a
    // zero count would wrap around and read far past the input.
    if (blockCount == 0)
    {
        return;
    }

    uint32x4_t hash_prev0 = vld1q_u32(m_IntermediateHash);
    uint32x4_t hash_prev1 = vld1q_u32(m_IntermediateHash + 4);
    uint32x4_t hash0      = vdupq_n_u32(0);
    uint32x4_t hash1      = vdupq_n_u32(0);

    do
    {
        // NOTE(review): every other build macro in this file is spelled NN_BUILD_CONFIG_*;
        // NN_BUILD_COMPILER_GCC may be a typo that sends GCC builds to the intrinsics
        // path below - confirm against the SDK's build configuration headers.
#if defined(NN_BUILD_CONFIG_CPU_ARM64) && (defined(NN_BUILD_CONFIG_COMPILER_CLANG) || defined(NN_BUILD_COMPILER_GCC))
        // Register use inside the asm block:
        //   v0-v3    message schedule words of the current block
        //   v4-v7    schedule words + round constants, input of a four-round step
        //   v16,v17  round constants loaded from RoundConstants
        //   v18      copy of hash0 taken before sha256h overwrites it, consumed by sha256h2
        __asm__ volatile
        (
            "ldp       q0, q1, [%[pData]]\n"
            "ldp       q2, q3, [%[pData], #32]\n"
            "add       %[hash0].4s, %[hash0].4s, %[hash_prev0].4s\n"
            "ldp       q16, q17, [%[RoundConstants]]\n"
            "add       %[hash1].4s, %[hash1].4s, %[hash_prev1].4s\n"
            "add       %[pData], %[pData], #64\n"

#if defined(NN_BUILD_CONFIG_ENDIAN_LITTLE)
            "rev32     v0.16b, v0.16b\n"
            "rev32     v1.16b, v1.16b\n"
            "rev32     v2.16b, v2.16b\n"
            "rev32     v3.16b, v3.16b\n"
#elif !defined(NN_BUILD_CONFIG_ENDIAN_BIG)
#error unknown NN_BUILD_CONFIG_ENDIAN
#endif

            "add       v4.4s, v0.4s, v16.4s\n"
            "add       v5.4s, v1.4s, v17.4s\n"
            "ldp       q16, q17, [%[RoundConstants], #32]\n"
            "sha256su0 v0.4s, v1.4s\n"

            "mov       %[hash_prev0].16b, %[hash0].16b\n"
            "sha256h   %q[hash0], %q[hash1], v4.4s\n"
            "mov       %[hash_prev1].16b, %[hash1].16b\n"
            "sha256h2  %q[hash1], %q[hash_prev0], v4.4s\n"
            "sha256su0 v1.4s, v2.4s\n"
            "sha256su1 v0.4s, v2.4s, v3.4s\n"

            "add       v6.4s, v2.4s, v16.4s\n"
            "mov       v18.16b, %[hash0].16b\n"
            "sha256h   %q[hash0], %q[hash1], v5.4s\n"
            "sha256h2  %q[hash1], q18, v5.4s\n"
            "sha256su0 v2.4s, v3.4s\n"
            "sha256su1 v1.4s, v3.4s, v0.4s\n"

            "add       v7.4s, v3.4s, v17.4s\n"
            "mov       v18.16b, %[hash0].16b\n"
            "ldp       q16, q17, [%[RoundConstants], #64]\n"
            "sha256h   %q[hash0], %q[hash1], v6.4s\n"
            "sha256h2  %q[hash1], q18, v6.4s\n"
            "sha256su0 v3.4s, v0.4s\n"
            "sha256su1 v2.4s, v0.4s, v1.4s\n"

            "add       v4.4s, v0.4s, v16.4s\n"
            "mov       v18.16b, %[hash0].16b\n"
            "sha256h   %q[hash0], %q[hash1], v7.4s\n"
            "sha256h2  %q[hash1], q18, v7.4s\n"
            "sha256su0 v0.4s, v1.4s\n"
            "sha256su1 v3.4s, v1.4s, v2.4s\n"

            "add       v5.4s, v1.4s, v17.4s\n"
            "mov       v18.16b, %[hash0].16b\n"
            "ldp       q16, q17, [%[RoundConstants], #96]\n"
            "sha256h   %q[hash0], %q[hash1], v4.4s\n"
            "sha256h2  %q[hash1], q18, v4.4s\n"
            "sha256su0 v1.4s, v2.4s\n"
            "sha256su1 v0.4s, v2.4s, v3.4s\n"

            "add       v6.4s, v2.4s, v16.4s\n"
            "mov       v18.16b, %[hash0].16b\n"
            "sha256h   %q[hash0], %q[hash1], v5.4s\n"
            "sha256h2  %q[hash1], q18, v5.4s\n"
            "sha256su0 v2.4s, v3.4s\n"
            "sha256su1 v1.4s, v3.4s, v0.4s\n"

            "add       v7.4s, v3.4s, v17.4s\n"
            "mov       v18.16b, %[hash0].16b\n"
            "ldp       q16, q17, [%[RoundConstants], #128]\n"
            "sha256h   %q[hash0], %q[hash1], v6.4s\n"
            "sha256h2  %q[hash1], q18, v6.4s\n"
            "sha256su0 v3.4s, v0.4s\n"
            "sha256su1 v2.4s, v0.4s, v1.4s\n"

            "add       v4.4s, v0.4s, v16.4s\n"
            "mov       v18.16b, %[hash0].16b\n"
            "sha256h   %q[hash0], %q[hash1], v7.4s\n"
            "sha256h2  %q[hash1], q18, v7.4s\n"
            "sha256su0 v0.4s, v1.4s\n"
            "sha256su1 v3.4s, v1.4s, v2.4s\n"

            "add       v5.4s, v1.4s, v17.4s\n"
            "mov       v18.16b, %[hash0].16b\n"
            "ldp       q16, q17, [%[RoundConstants], #160]\n"
            "sha256h   %q[hash0], %q[hash1], v4.4s\n"
            "sha256h2  %q[hash1], q18, v4.4s\n"
            "sha256su0 v1.4s, v2.4s\n"
            "sha256su1 v0.4s, v2.4s, v3.4s\n"

            "add       v6.4s, v2.4s, v16.4s\n"
            "mov       v18.16b, %[hash0].16b\n"
            "sha256h   %q[hash0], %q[hash1], v5.4s\n"
            "sha256h2  %q[hash1], q18, v5.4s\n"
            "sha256su0 v2.4s, v3.4s\n"
            "sha256su1 v1.4s, v3.4s, v0.4s\n"

            "add       v7.4s, v3.4s, v17.4s\n"
            "mov       v18.16b, %[hash0].16b\n"
            "ldp       q16, q17, [%[RoundConstants], #192]\n"
            "sha256h   %q[hash0], %q[hash1], v6.4s\n"
            "sha256h2  %q[hash1], q18, v6.4s\n"
            "sha256su0 v3.4s, v0.4s\n"
            "sha256su1 v2.4s, v0.4s, v1.4s\n"

            "add       v4.4s, v0.4s, v16.4s\n"
            "mov       v18.16b, %[hash0].16b\n"
            "sha256h   %q[hash0], %q[hash1], v7.4s\n"
            "sha256h2  %q[hash1], q18, v7.4s\n"
            "sha256su1 v3.4s, v1.4s, v2.4s\n"

            "add       v5.4s, v1.4s, v17.4s\n"
            "mov       v18.16b, %[hash0].16b\n"
            "ldp       q16, q17, [%[RoundConstants], #224]\n"
            "sha256h   %q[hash0], %q[hash1], v4.4s\n"
            "sha256h2  %q[hash1], q18, v4.4s\n"

            "add       v6.4s, v2.4s, v16.4s\n"
            "mov       v18.16b, %[hash0].16b\n"
            "sha256h   %q[hash0], %q[hash1], v5.4s\n"
            "sha256h2  %q[hash1], q18, v5.4s\n"

            "add       v7.4s, v3.4s, v17.4s\n"
            "mov       v18.16b, %[hash0].16b\n"
            "sha256h   %q[hash0], %q[hash1], v6.4s\n"
            "sha256h2  %q[hash1], q18, v6.4s\n"

            "mov       v18.16b, %[hash0].16b\n"
            "sha256h   %q[hash0], %q[hash1], v7.4s\n"
            "sha256h2  %q[hash1], q18, v7.4s\n"

            : [hash0]"+w"(hash0), [hash1]"+w"(hash1),
              [hash_prev0]"+w"(hash_prev0), [hash_prev1]"+w"(hash_prev1),
              [pData]"+r"(pData)
            : [RoundConstants]"r"(RoundConstants)
            // "memory": the block reads the input and the constant table through pointer
            // operands only, so the compiler must be told not to move or drop stores to
            // that memory (e.g. into m_TemporalBlockBuffer) around the asm statement.
            : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v16", "v17", "v18", "memory"
        );

#else
        // Same code as above, but with intrinsics. Not as fast because the compiler reorders the instructions in a less optimal way.

        uint32x4_t s0 = vreinterpretq_u32_u8(vld1q_u8(pData)); pData += 16;
        uint32x4_t s1 = vreinterpretq_u32_u8(vld1q_u8(pData)); pData += 16;
        uint32x4_t s2 = vreinterpretq_u32_u8(vld1q_u8(pData)); pData += 16;
        uint32x4_t s3 = vreinterpretq_u32_u8(vld1q_u8(pData)); pData += 16;

        // Deferred feed-forward of the previous block, then remember the state this block starts from.
        hash0 = vaddq_u32(hash0, hash_prev0);
        hash1 = vaddq_u32(hash1, hash_prev1);
        hash_prev0 = hash0;
        hash_prev1 = hash1;

#if defined(NN_BUILD_CONFIG_ENDIAN_LITTLE)
        s0 = vreinterpretq_u32_u8(vrev32q_u8(vreinterpretq_u8_u32(s0)));
        s1 = vreinterpretq_u32_u8(vrev32q_u8(vreinterpretq_u8_u32(s1)));
        s2 = vreinterpretq_u32_u8(vrev32q_u8(vreinterpretq_u8_u32(s2)));
        s3 = vreinterpretq_u32_u8(vrev32q_u8(vreinterpretq_u8_u32(s3)));
#elif !defined(NN_BUILD_CONFIG_ENDIAN_BIG)
#error unknown NN_BUILD_CONFIG_ENDIAN
#endif

        // a0-a3: schedule words plus round constants; tmp: hash0 before vsha256hq_u32 replaces it.
        uint32x4_t a0, a1, a2, a3, tmp;

        a0    = vaddq_u32(s0, vld1q_u32(RoundConstants));
        a1    = vaddq_u32(s1, vld1q_u32(RoundConstants +  4));
        s0    = vsha256su0q_u32(s0, s1);

        tmp   = hash0;
        hash0 = vsha256hq_u32(hash0, hash1, a0);
        hash1 = vsha256h2q_u32(hash1, tmp, a0);
        s1    = vsha256su0q_u32(s1, s2);
        s0    = vsha256su1q_u32(s0, s2, s3);

        a2    = vaddq_u32(s2, vld1q_u32(RoundConstants +  8));
        tmp   = hash0;
        hash0 = vsha256hq_u32(hash0, hash1, a1);
        hash1 = vsha256h2q_u32(hash1, tmp, a1);
        s2    = vsha256su0q_u32(s2, s3);
        s1    = vsha256su1q_u32(s1, s3, s0);

        a3    = vaddq_u32(s3, vld1q_u32(RoundConstants + 12));
        tmp   = hash0;
        hash0 = vsha256hq_u32(hash0, hash1, a2);
        hash1 = vsha256h2q_u32(hash1, tmp, a2);
        s3    = vsha256su0q_u32(s3, s0);
        s2    = vsha256su1q_u32(s2, s0, s1);

        a0    = vaddq_u32(s0, vld1q_u32(RoundConstants + 16));
        tmp   = hash0;
        hash0 = vsha256hq_u32(hash0, hash1, a3);
        hash1 = vsha256h2q_u32(hash1, tmp, a3);
        s0    = vsha256su0q_u32(s0, s1);
        s3    = vsha256su1q_u32(s3, s1, s2);

        a1    = vaddq_u32(s1, vld1q_u32(RoundConstants + 20));
        tmp   = hash0;
        hash0 = vsha256hq_u32(hash0, hash1, a0);
        hash1 = vsha256h2q_u32(hash1, tmp, a0);
        s1    = vsha256su0q_u32(s1, s2);
        s0    = vsha256su1q_u32(s0, s2, s3);

        a2    = vaddq_u32(s2, vld1q_u32(RoundConstants + 24));
        tmp   = hash0;
        hash0 = vsha256hq_u32(hash0, hash1, a1);
        hash1 = vsha256h2q_u32(hash1, tmp, a1);
        s2    = vsha256su0q_u32(s2, s3);
        s1    = vsha256su1q_u32(s1, s3, s0);

        a3    = vaddq_u32(s3, vld1q_u32(RoundConstants + 28));
        tmp   = hash0;
        hash0 = vsha256hq_u32(hash0, hash1, a2);
        hash1 = vsha256h2q_u32(hash1, tmp, a2);
        s3    = vsha256su0q_u32(s3, s0);
        s2    = vsha256su1q_u32(s2, s0, s1);

        a0    = vaddq_u32(s0, vld1q_u32(RoundConstants + 32));
        tmp   = hash0;
        hash0 = vsha256hq_u32(hash0, hash1, a3);
        hash1 = vsha256h2q_u32(hash1, tmp, a3);
        s0    = vsha256su0q_u32(s0, s1);
        s3    = vsha256su1q_u32(s3, s1, s2);

        a1    = vaddq_u32(s1, vld1q_u32(RoundConstants + 36));
        tmp   = hash0;
        hash0 = vsha256hq_u32(hash0, hash1, a0);
        hash1 = vsha256h2q_u32(hash1, tmp, a0);
        s1    = vsha256su0q_u32(s1, s2);
        s0    = vsha256su1q_u32(s0, s2, s3);

        a2    = vaddq_u32(s2, vld1q_u32(RoundConstants + 40));
        tmp   = hash0;
        hash0 = vsha256hq_u32(hash0, hash1, a1);
        hash1 = vsha256h2q_u32(hash1, tmp, a1);
        s2    = vsha256su0q_u32(s2, s3);
        s1    = vsha256su1q_u32(s1, s3, s0);

        a3    = vaddq_u32(s3, vld1q_u32(RoundConstants + 44));
        tmp   = hash0;
        hash0 = vsha256hq_u32(hash0, hash1, a2);
        hash1 = vsha256h2q_u32(hash1, tmp, a2);
        s3    = vsha256su0q_u32(s3, s0);
        s2    = vsha256su1q_u32(s2, s0, s1);

        a0    = vaddq_u32(s0, vld1q_u32(RoundConstants + 48));
        tmp   = hash0;
        hash0 = vsha256hq_u32(hash0, hash1, a3);
        hash1 = vsha256h2q_u32(hash1, tmp, a3);
        s3    = vsha256su1q_u32(s3, s1, s2);

        a1    = vaddq_u32(s1, vld1q_u32(RoundConstants + 52));
        tmp   = hash0;
        hash0 = vsha256hq_u32(hash0, hash1, a0);
        hash1 = vsha256h2q_u32(hash1, tmp, a0);

        a2    = vaddq_u32(s2, vld1q_u32(RoundConstants + 56));
        tmp   = hash0;
        hash0 = vsha256hq_u32(hash0, hash1, a1);
        hash1 = vsha256h2q_u32(hash1, tmp, a1);

        a3    = vaddq_u32(s3, vld1q_u32(RoundConstants + 60));
        tmp   = hash0;
        hash0 = vsha256hq_u32(hash0, hash1, a2);
        hash1 = vsha256h2q_u32(hash1, tmp, a2);

        tmp   = hash0;
        hash0 = vsha256hq_u32(hash0, hash1, a3);
        hash1 = vsha256h2q_u32(hash1, tmp, a3);
#endif
    }
    while (--blockCount != 0);

    // Feed-forward of the last block, which no further iteration performed.
    hash0 = vaddq_u32(hash0, hash_prev0);
    hash1 = vaddq_u32(hash1, hash_prev1);

    vst1q_u32(m_IntermediateHash,     hash0);
    vst1q_u32(m_IntermediateHash + 4, hash1);
} // NOLINT(impl/function_size)

// Finishes the hash: appends the 0x80 marker, zero padding and the 64-bit big-endian
// message length to the buffered tail, hashing one or two final blocks as needed.
// Called once from GetHash(); it leaves the final result in m_IntermediateHash.
void Sha256Impl::ProcessLastBlock() NN_NOEXCEPT
{
    // Offset of the 8-byte length field inside a block. Kept as size_t because it is
    // compared with and subtracted from the unsigned m_BufferedByte.
    const size_t BlockSizeWithoutSizeField = BlockSize - sizeof(Bit64);

    /* Account for the bytes that are still sitting in the buffer. */
    m_InputBitCount += 8 * m_BufferedByte;

    /* Write the 0x80 byte that marks the start of the padding. */
    m_TemporalBlockBuffer[m_BufferedByte] = 0x80;
    m_BufferedByte++;

    /* What happens next depends on whether the current block still has room for the length field. */
    if (m_BufferedByte <= BlockSizeWithoutSizeField)
    {
        /* It fits: zero-pad up to the start of the length field. */
        std::memset(m_TemporalBlockBuffer + m_BufferedByte, 0x00, BlockSizeWithoutSizeField - m_BufferedByte);
    }
    else
    {
        /* It does not fit: zero-pad this block to its end and hash it. */
        std::memset(m_TemporalBlockBuffer + m_BufferedByte, 0x00, BlockSize - m_BufferedByte);
        ProcessBlock(m_TemporalBlockBuffer);

        /* Start another block that is all zeros up to the length field. */
        std::memset(m_TemporalBlockBuffer, 0x00, BlockSizeWithoutSizeField);
    }

    /* Store the message length in bits in the last 8 bytes and hash the final block. */
    Bit64 inputBitCount = Convert64BitToBigEndian(m_InputBitCount);
    std::memcpy(m_TemporalBlockBuffer + BlockSizeWithoutSizeField, &inputBitCount, sizeof(Bit64));

    ProcessBlock(m_TemporalBlockBuffer);
}

// Resumes hashing from a saved context: takes the intermediate hash and the processed-bit
// count from *pContext, starts with an empty block buffer and makes the object ready for
// Update()/GetHash().
void Sha256Impl::InitializeWithContext(const Sha256Context* pContext) NN_NOEXCEPT
{
    // A context carries no partial block, so buffering starts from scratch.
    m_BufferedByte = 0;

    m_InputBitCount = pContext->_inputBitCount;
    std::memcpy(m_IntermediateHash, pContext->_intermediateHash, sizeof(m_IntermediateHash));

    m_State = State_Initialized;
}

// Saves the current intermediate hash and processed-bit count into *pContext.
// Returns the number of input bytes that are buffered but not yet hashed; those bytes
// are not part of the context itself.
size_t Sha256Impl::GetContext(Sha256Context* pContext) const NN_NOEXCEPT
{
    NN_SDK_REQUIRES(m_State == State_Initialized);

    pContext->_inputBitCount = m_InputBitCount;
    std::memcpy(pContext->_intermediateHash, m_IntermediateHash, sizeof(m_IntermediateHash));

    return m_BufferedByte;
}

}}}
