﻿/*--------------------------------------------------------------------------------*
  Copyright (C)Nintendo All rights reserved.

  These coded instructions, statements, and computer programs contain proprietary
  information of Nintendo and/or its licensed developers and are protected by
  national and international copyright laws. They may not be disclosed to third
  parties or copied or duplicated in any form, in whole or in part, without the
  prior written consent of Nintendo.

  The content herein is highly confidential and should be handled accordingly.
 *--------------------------------------------------------------------------------*/

#include <cstring>
#include <algorithm>
#include <nn/nn_SdkAssert.h>
#include <nn/nn_Common.h>
#include <nn/nn_Log.h>
#include <nn/crypto/crypto_Config.h>
#include <nn/crypto/detail/crypto_Sha1Impl.h>
#include <nn/crypto/detail/crypto_Clear.h>

#include "crypto_Util.h"
#include <arm_neon.h>

namespace nn { namespace crypto { namespace detail {

namespace
{
    /* 各ラウンドの処理で使用する定数 */
    const uint32_t RoundConstants[4] =
    {
        0x5a827999,
        0x6ed9eba1,
        0x8f1bbcdc,
        0xca62c1d6
    };

}   // anonymous namespace


Sha1Impl::~Sha1Impl() NN_NOEXCEPT
{
    ClearMemory(this, sizeof(*this));
}

void Sha1Impl::Initialize() NN_NOEXCEPT
{
    m_InputBitCount = 0;
    m_BufferedByte = 0;

    m_IntermediateHash[0] = 0x67452301;
    m_IntermediateHash[1] = 0xefcdab89;
    m_IntermediateHash[2] = 0x98badcfe;
    m_IntermediateHash[3] = 0x10325476;
    m_IntermediateHash[4] = 0xc3d2e1f0;

    m_State = State_Initialized;
}

void Sha1Impl::Update(const void* pData, size_t dataSize) NN_NOEXCEPT
{
    NN_SDK_REQUIRES(m_State == State_Initialized, "Invalid state. Please restart from Initialize().");

    /* 実際に処理される分だけデータサイズをビットサイズで加算していく */
    m_InputBitCount += 8 * ((m_BufferedByte + dataSize) / BlockSize) * BlockSize;

    const Bit8*  pData8 = static_cast<const Bit8*>(pData);
    size_t remaining = dataSize;

    /* 前の処理の残りがあったら1ブロックに到達するかデータが無くなるまで埋める */
    if (m_BufferedByte > 0)
    {
        size_t fillSize = std::min(BlockSize - m_BufferedByte, remaining);

        std::memcpy(m_TemporalBlockBuffer + m_BufferedByte, pData8, fillSize);
        pData8 += fillSize;
        remaining -= fillSize;
        m_BufferedByte += fillSize;
        if (m_BufferedByte == BlockSize)
        {
            ProcessBlock(m_TemporalBlockBuffer);
            m_BufferedByte = 0;
        }
    }

    /* ブロックサイズ以上の残りがある場合はブロックサイズごとに処理 */
    while (remaining >= BlockSize)
    {
        ProcessBlock(pData8);
        pData8 += BlockSize;
        remaining -= BlockSize;
    }

    /* ブロックサイズ以下の端数は次の処理のために保存しておく */
    if (remaining > 0)
    {
        m_BufferedByte = remaining;
        std::memcpy(m_TemporalBlockBuffer, pData8, remaining);
    }
}

void Sha1Impl::GetHash(void* pHash, size_t hashSize) NN_NOEXCEPT
{
    NN_SDK_REQUIRES((m_State == State_Initialized) || (m_State == State_Done), "Invalid state. Please restart from Initialize().");
    NN_SDK_REQUIRES(hashSize >= HashSize, "It requires %d bytes buffer", HashSize);
    NN_UNUSED(hashSize);

    if (m_State == State_Initialized)
    {
        ProcessLastBlock();
        m_State = State_Done;
    }

#if defined(NN_BUILD_CONFIG_ENDIAN_LITTLE)
    CopyDataWithSwappingEndianBy32Bit(pHash, m_IntermediateHash, HashSize);
#elif defined(NN_BUILD_CONFIG_ENDIAN_BIG)
    std::memcpy(pHash, m_IntermediateHash, HashSize);
#else
#error unknown NN_BUILD_CONFIG_ENDIAN
#endif
}

void Sha1Impl::ProcessBlock(const void* pData) NN_NOEXCEPT
{
    uint32x4_t      hash_prev0;
    uint32_t        hash_prev1;
    uint32x4_t      hash_abcd;
    uint32_t        hash_e;
    uint32x4_t      w[20];
    uint32x4_t      k[4];
    int             t;

    hash_prev0 = vld1q_u32(m_IntermediateHash);
    hash_prev1 = m_IntermediateHash[4];

    /* 最初の 16 word には入力データをそのまま格納する */
    const uint8_t*  pData8 = static_cast<const uint8_t*>(pData);
    for (t = 0; t < 4; t++)
    {
        /* 8x16 のベクタとして解釈させる */
        uint8x16_t tmp = vld1q_u8(pData8);
#if defined(NN_BUILD_CONFIG_ENDIAN_LITTLE)
        /* リトルエンディアンのときはエンディアンを変換 */
        tmp = vrev32q_u8(tmp);
#elif defined(NN_BUILD_CONFIG_ENDIAN_BIG)
#else
#error unknown NN_BUILD_CONFIG_ENDIAN
#endif
        /* それを 32x4 のベクタに再解釈させる */
        w[t] = vreinterpretq_u32_u8(tmp);

        pData8 += 16;
    }

    /* 残りの 64 word はスケジュール更新命令で計算されたデータを格納する */
    for (; t < 20; t++)
    {
        w[t] = vsha1su0q_u32(w[t - 4], w[t - 3], w[t - 2]);
        w[t] = vsha1su1q_u32(w[t], w[t - 1]);
    }

    /* 各ラウンドで使用する定数を保持するベクタを作る */
    for (t = 0; t < 4; t++)
    {
        k[t] = vdupq_n_u32(RoundConstants[t]);
    }

    /* ハッシュ計算の実体 */
    hash_abcd = hash_prev0;
    hash_e    = hash_prev1;

    /* 最初の 20 ラウンド分は VSHA1C */
    for (t = 0; t < 5; t++)
    {
        uint32x4_t wk = vaddq_u32(w[t], k[0]);
        uint32_t   hash_a = vgetq_lane_u32(hash_abcd, 0);

        hash_abcd = vsha1cq_u32(hash_abcd, hash_e, wk);
        hash_e    = vsha1h_u32(hash_a);
    }

    /* 次の 20 ラウンド分は VSHA1P */
    for (; t < 10; t++)
    {
        uint32x4_t wk = vaddq_u32(w[t], k[1]);
        uint32_t   hash_a = vgetq_lane_u32(hash_abcd, 0);

        hash_abcd = vsha1pq_u32(hash_abcd, hash_e, wk);
        hash_e    = vsha1h_u32(hash_a);
    }

    /* その次の 20 ラウンド分は VSHA1M */
    for (; t < 15; t++)
    {
        uint32x4_t wk = vaddq_u32(w[t], k[2]);
        uint32_t   hash_a = vgetq_lane_u32(hash_abcd, 0);

        hash_abcd = vsha1mq_u32(hash_abcd, hash_e, wk);
        hash_e    = vsha1h_u32(hash_a);
    }

    /* 最後の 20 ラウンド分は VSHA1P */
    for (; t < 20; t++)
    {
        uint32x4_t wk = vaddq_u32(w[t], k[3]);
        uint32_t   hash_a = vgetq_lane_u32(hash_abcd, 0);

        hash_abcd = vsha1pq_u32(hash_abcd, hash_e, wk);
        hash_e    = vsha1h_u32(hash_a);
    }

    /* 計算された値を計算前の中間ハッシュ値に足し込んで計算完了 */
    hash_abcd = vaddq_u32(hash_abcd, hash_prev0);
    hash_e    = hash_e + hash_prev1;

    vst1q_u32(m_IntermediateHash, hash_abcd);
    m_IntermediateHash[4] = hash_e;
}

void Sha1Impl::ProcessLastBlock() NN_NOEXCEPT
{
    const int BlockSizeWithoutSizeField = BlockSize - sizeof(Bit64);

    /* 最後にバッファされているデータ分のデータサイズを加算 */
    m_InputBitCount += 8 * m_BufferedByte;

    /* パディングの先頭を示す 0x80 を代入 */
    m_TemporalBlockBuffer[m_BufferedByte] = 0x80;
    m_BufferedByte++;

    /* 現在計算中のブロックにサイズを埋め込む余裕があるかないかで処理が変わる */
    if (m_BufferedByte <= BlockSizeWithoutSizeField)
    {
        /* そのままサイズを格納する領域の手前までパディング */
        std::memset(m_TemporalBlockBuffer + m_BufferedByte, 0x00, BlockSizeWithoutSizeField - m_BufferedByte);
    }
    else
    {
        /* このブロックは末尾までパディングしてハッシュ計算を行う */
        std::memset(m_TemporalBlockBuffer + m_BufferedByte, 0x00, BlockSize - m_BufferedByte);
        ProcessBlock(m_TemporalBlockBuffer);

        /* 次のブロックをサイズを格納する領域の手前までパディング */
        std::memset(m_TemporalBlockBuffer, 0x00, BlockSizeWithoutSizeField);
    }

    /* 最後の 8 バイトにメッセージの長さを入れてハッシュを計算 */
    Bit64 inputBitCount = Convert64BitToBigEndian(m_InputBitCount);
    std::memcpy(m_TemporalBlockBuffer + BlockSizeWithoutSizeField, &inputBitCount, sizeof(Bit64));

    ProcessBlock(m_TemporalBlockBuffer);
}

}}}
