﻿/*--------------------------------------------------------------------------------*
  Copyright (C)Nintendo All rights reserved.

  These coded instructions, statements, and computer programs contain proprietary
  information of Nintendo and/or its licensed developers and are protected by
  national and international copyright laws. They may not be disclosed to third
  parties or copied or duplicated in any form, in whole or in part, without the
  prior written consent of Nintendo.

  The content herein is highly confidential and should be handled accordingly.
 *--------------------------------------------------------------------------------*/

#include <cstdlib>
#include <cstring>
#include <algorithm>
#include <nn/nn_Common.h>
#include <nn/nn_Log.h>
#include <nn/crypto/crypto_Sha256Generator.h>

#include "testCrypto_Util.h"

struct Sha256TestVector
{
    char       data[nn::crypto::Sha256Generator::BlockSize + 1];
    size_t     dataSize;
    nn::Bit8   expectedHash[nn::crypto::Sha256Generator::HashSize];
};

/* http://csrc.nist.gov/groups/ST/toolkit/documents/Examples/SHA2_Additional.pdf
   に記載されているテストベクトルに null message を追加したもの */
const Sha256TestVector sha256TestVectors[] =
{
    // #0) 0 byte (null message) (入力メッセージがない)
    {
        "",
        0,
        {0xe3, 0xb0, 0xc4, 0x42, 0x98, 0xfc, 0x1c, 0x14, 0x9a, 0xfb, 0xf4, 0xc8, 0x99, 0x6f, 0xb9, 0x24,
         0x27, 0xae, 0x41, 0xe4, 0x64, 0x9b, 0x93, 0x4c, 0xa4, 0x95, 0x99, 0x1b, 0x78, 0x52, 0xb8, 0x55},
    },

    // #1) 1 byte 0xbd (ブロック長より短いメッセージ1)
    {
        "\xbd",
        1,
        {0x68, 0x32, 0x57, 0x20, 0xaa, 0xbd, 0x7c, 0x82, 0xf3, 0x0f, 0x55, 0x4b, 0x31, 0x3d, 0x05, 0x70,
         0xc9, 0x5a, 0xcc, 0xbb, 0x7d, 0xc4, 0xb5, 0xaa, 0xe1, 0x12, 0x04, 0xc0, 0x8f, 0xfe, 0x73, 0x2b},
    },

    // #2) 4 bytes 0xc98c8e55 (ブロック長より短いメッセージ2)
    {
        "\xc9\x8c\x8e\x55",
        4,
        {0x7a, 0xbc, 0x22, 0xc0, 0xae, 0x5a, 0xf2, 0x6c, 0xe9, 0x3d, 0xbb, 0x94, 0x43, 0x3a, 0x0e, 0x0b,
         0x2e, 0x11, 0x9d, 0x01, 0x4f, 0x8e, 0x7f, 0x65, 0xbd, 0x56, 0xc6, 0x1c, 0xcc, 0xcd, 0x95, 0x04},
    },

    // #3) 55 bytes of zeros (パディングを考慮するとちょうど1ブロックになるメッセージ)
    {
        "",
        55,
        {0x02, 0x77, 0x94, 0x66, 0xcd, 0xec, 0x16, 0x38, 0x11, 0xd0, 0x78, 0x81, 0x5c, 0x63, 0x3f, 0x21,
         0x90, 0x14, 0x13, 0x08, 0x14, 0x49, 0x00, 0x2f, 0x24, 0xaa, 0x3e, 0x80, 0xf0, 0xb8, 0x8e, 0xf7},
    },

    // #4) 56 bytes of zeros (パディングを考慮するとちょうど1ブロックを超えるメッセージ)
    {
        "",
        56,
        {0xd4, 0x81, 0x7a, 0xa5, 0x49, 0x76, 0x28, 0xe7, 0xc7, 0x7e, 0x6b, 0x60, 0x61, 0x07, 0x04, 0x2b,
         0xbb, 0xa3, 0x13, 0x08, 0x88, 0xc5, 0xf4, 0x7a, 0x37, 0x5e, 0x61, 0x79, 0xbe, 0x78, 0x9f, 0xbb},
    },

    // #5) 57 bytes of zeros (パディングを考慮すると1ブロックを超えるメッセージ)
    {
        "",
        57,
        {0x65, 0xa1, 0x6c, 0xb7, 0x86, 0x13, 0x35, 0xd5, 0xac, 0xe3, 0xc6, 0x07, 0x18, 0xb5, 0x05, 0x2e,
         0x44, 0x66, 0x07, 0x26, 0xda, 0x4c, 0xd1, 0x3b, 0xb7, 0x45, 0x38, 0x1b, 0x23, 0x5a, 0x17, 0x85},
    },

    // #6) 64 bytes of zeros (1ブロックサイズ分のメッセージ)
    {
        "",
        64,
        {0xf5, 0xa5, 0xfd, 0x42, 0xd1, 0x6a, 0x20, 0x30, 0x27, 0x98, 0xef, 0x6e, 0xd3, 0x09, 0x97, 0x9b,
         0x43, 0x00, 0x3d, 0x23, 0x20, 0xd9, 0xf0, 0xe8, 0xea, 0x98, 0x31, 0xa9, 0x27, 0x59, 0xfb, 0x4b},
    },

    // #7) 1000 bytes of zeros (長めのメッセージ)
    {
        "",
        1000,
        {0x54, 0x1b, 0x3e, 0x9d, 0xaa, 0x09, 0xb2, 0x0b, 0xf8, 0x5f, 0xa2, 0x73, 0xe5, 0xcb, 0xd3, 0xe8,
         0x01, 0x85, 0xaa, 0x4e, 0xc2, 0x98, 0xe7, 0x65, 0xdb, 0x87, 0x74, 0x2b, 0x70, 0x13, 0x8a, 0x53},
    },

    // #8) 1000 bytes of 0x41 'A' (ヌル文字列ではない長めのメッセージ)
    {
        "A",
        1000,
        {0xc2, 0xe6, 0x86, 0x82, 0x34, 0x89, 0xce, 0xd2, 0x01, 0x7f, 0x60, 0x59, 0xb8, 0xb2, 0x39, 0x31,
         0x8b, 0x63, 0x64, 0xf6, 0xdc, 0xd8, 0x35, 0xd0, 0xa5, 0x19, 0x10, 0x5a, 0x1e, 0xad, 0xd6, 0xe4},
    },

    // #9) 1005 bytes of 0x55 'U' (ヌル文字列ではない長めのメッセージ)
    {
        "U",
        1005,
        {0xf4, 0xd6, 0x2d, 0xde, 0xc0, 0xf3, 0xdd, 0x90, 0xea, 0x13, 0x80, 0xfa, 0x16, 0xa5, 0xff, 0x8d,
         0xc4, 0xc5, 0x4b, 0x21, 0x74, 0x06, 0x50, 0xf2, 0x4a, 0xfc, 0x41, 0x20, 0x90, 0x35, 0x52, 0xb0},
    },

    // #10) 1000000 bytes of zeros (比較的長めのメッセージ)
    {
        "",
        1000000,
        {0xd2, 0x97, 0x51, 0xf2, 0x64, 0x9b, 0x32, 0xff, 0x57, 0x2b, 0x5e, 0x0a, 0x9f, 0x54, 0x1e, 0xa6,
         0x60, 0xa5, 0x0f, 0x94, 0xff, 0x0b, 0xee, 0xdf, 0xb0, 0xb6, 0x92, 0xb9, 0x24, 0xcc, 0x80, 0x25},
    },
#if 0
    // #11) 0x20000000 (536870912) bytes of 0x5a 'Z' (かなり長めのメッセージ1)
    {
        "ZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZ",
        0x20000000,
        {0x15, 0xa1, 0x86, 0x8c, 0x12, 0xcc, 0x53, 0x95, 0x1e, 0x18, 0x23, 0x44, 0x27, 0x74, 0x47, 0xcd,
         0x09, 0x79, 0x53, 0x6b, 0xad, 0xcc, 0x51, 0x2a, 0xd2, 0x4c, 0x67, 0xe9, 0xb2, 0xd4, 0xf3, 0xdd},
    },

    // #12) 0x41000000 (1090519040) bytes of zeros (かなり長めのメッセージ2)
    {
        "",
        0x41000000,
        {0x46, 0x1c, 0x19, 0xa9, 0x3b, 0xd4, 0x34, 0x4f, 0x92, 0x15, 0xf5, 0xec, 0x64, 0x35, 0x70, 0x90,
         0x34, 0x2b, 0xc6, 0x6b, 0x15, 0xa1, 0x48, 0x31, 0x7d, 0x27, 0x6e, 0x31, 0xcb, 0xc2, 0x0b, 0x53},
    },

    // #13) 0x6000003e (1610612798) bytes of 0x42 'B' (かなり長めのメッセージ3)
    {
        "BBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBB",
        0x6000003e,
        {0xc2, 0x3c, 0xe8, 0xa7, 0x89, 0x5f, 0x4b, 0x21, 0xec, 0x0d, 0xaf, 0x37, 0x92, 0x0a, 0xc0, 0xa2,
         0x62, 0xa2, 0x20, 0x04, 0x5a, 0x03, 0xeb, 0x2d, 0xfe, 0xd4, 0x8e, 0xf9, 0xb0, 0x5a, 0xab, 0xea},
    },
#endif
};

const int TestVectorCount = sizeof(sha256TestVectors) / sizeof(sha256TestVectors[0]);

void Sha256BasicTest(const Sha256TestVector& testVector)
{
    nn::crypto::Sha256Generator sha256;
    nn::Bit8                    hash[nn::crypto::Sha256Generator::HashSize];

    sha256.Initialize();
    if (testVector.dataSize <= nn::crypto::Sha256Generator::BlockSize)
    {
        // 入力がブロックサイズより小さい場合は一括で計算
        sha256.Update(testVector.data, testVector.dataSize);
    }
    else
    {
        // 入力が大きい場合は繰り返し単位ごとに分割して計算
        size_t remaining = testVector.dataSize;
        while (remaining != 0 )
        {
            size_t calculationSize = std::min(std::strlen(testVector.data), remaining);
            calculationSize = (calculationSize == 0) ? 1: calculationSize; // ヌル文字列の場合は1を入力
            sha256.Update(testVector.data, calculationSize);
            remaining -= calculationSize;
        }
    }
    sha256.GetHash(hash, sizeof(hash));

    EXPECT_ARRAY_EQ(hash, testVector.expectedHash, nn::crypto::Sha256Generator::HashSize);
}

void Sha256FunctionTest(const Sha256TestVector& testVector)
{
    nn::Bit8  hash[nn::crypto::Sha256Generator::HashSize];

    if (testVector.dataSize <= nn::crypto::Sha256Generator::BlockSize)
    {
        nn::crypto::GenerateSha256Hash(hash,
                                       sizeof(hash),
                                       testVector.data,
                                       testVector.dataSize);

        EXPECT_ARRAY_EQ(hash, testVector.expectedHash, nn::crypto::Sha256Generator::HashSize);
    }
    else
    {
        char* buf = static_cast<char*>(std::malloc(testVector.dataSize));
        ASSERT_NE(nullptr, buf);

        for (size_t i=0; i<testVector.dataSize; ++i)
            buf[i] = testVector.data[0];

        nn::crypto::GenerateSha256Hash(hash, sizeof(hash), buf, testVector.dataSize);

        EXPECT_ARRAY_EQ(hash, testVector.expectedHash, nn::crypto::Sha256Generator::HashSize);

        std::free(buf);
    }
}

/**
  @brief   Sha256Generator クラスによるハッシュ計算をテストします。

  @details
  NIST が公表しているテストベクトルを用いて、
  正しいハッシュ値が計算されることを確認します。
 */
TEST(Sha256Test, Basic)
{
    for (int i = 0; i < TestVectorCount; i++)
    {
        Sha256BasicTest(sha256TestVectors[i]);
    }
}

/**
  @brief   GenerateSha256Hash 関数によるハッシュ計算をテストします。

  @details
  NIST が公表しているテストベクトルを用いて、
  正しいハッシュ値が計算されることを確認します。
 */
TEST(Sha256Test, FunctionInterface)
{
    // 短いテストベクトルについてのみテストする
    for (int i = 0; i < TestVectorCount; i++)
    {
        Sha256FunctionTest(sha256TestVectors[i]);
    }
}

/**
  @brief   Sha256Generator クラスの状態遷移をテストします。
 */
TEST(Sha256Test, StateTransition)
{
    nn::crypto::Sha256Generator sha256;
    nn::Bit8                    hash[nn::crypto::Sha256Generator::HashSize];
    nn::Bit8                    hash2[nn::crypto::Sha256Generator::HashSize];

#if defined(NN_SDK_BUILD_DEBUG) || defined(NN_SDK_BUILD_DEVELOP)
    // 初期化せずに Update が呼ばれたら NG
    EXPECT_DEATH_IF_SUPPORTED(sha256.Update(sha256TestVectors[0].data, sha256TestVectors[0].dataSize), "");

    // 初期化せずに GetHash が呼ばれても NG
    EXPECT_DEATH_IF_SUPPORTED(sha256.GetHash(hash, sizeof(hash)), "");
#endif

    // 初期化
    sha256.Initialize();

    // 初期化の後に GetHash が呼ばれても大丈夫
    sha256.GetHash(hash, sizeof(hash));

    // GetHash は連続で呼んでも大丈夫で、同じ値が出力されるはず
    sha256.GetHash(hash2, sizeof(hash2));

    EXPECT_ARRAY_EQ(hash, hash2, nn::crypto::Sha256Generator::HashSize);

#if defined(NN_SDK_BUILD_DEBUG) || defined(NN_SDK_BUILD_DEVELOP)
    // GetHash の後に Update が呼ばれたら NG (Release ビルドでは ASSERT が無効のため停止しない)
    EXPECT_DEATH_IF_SUPPORTED(sha256.Update(sha256TestVectors[0].data, sha256TestVectors[0].dataSize), "");
#endif
}

/**
  @brief   非 32bit アラインメントの入出力バッファに対する挙動をテストします。
 */
TEST(Sha256Test, UnalignedIoBuffer)
{
    nn::crypto::Sha256Generator sha256;
    nn::Bit8                    hash[nn::crypto::Sha256Generator::HashSize + sizeof(nn::Bit64)] = {0};
    nn::Bit8                    data[nn::crypto::Sha256Generator::BlockSize + sizeof(nn::Bit64)] = {0};

    // http://csrc.nist.gov/groups/ST/toolkit/documents/Examples/SHA256.pdf にある、
    // 処理する位置・順序がずれたら計算結果が変わってしまうようなテストベクトルを使用する
    Sha256TestVector testVector =
    {
        "abcdbcdecdefdefgefghfghighijhijkijkljklmklmnlmnomnopnopq",
        56,
        {0x24, 0x8D, 0x6A, 0x61, 0xD2, 0x06, 0x38, 0xB8, 0xE5, 0xC0, 0x26, 0x93, 0x0C, 0x3E, 0x60, 0x39,
         0xA3, 0x3C, 0xE4, 0x59, 0x64, 0xFF, 0x21, 0x67, 0xF6, 0xEC, 0xED, 0xD4, 0x19, 0xDB, 0x06, 0xC1},
    };

    // 非 32bit アラインメントの入力バッファに対して問題ないことをテストする
    for (int i = 0; i < static_cast<int>(sizeof(nn::Bit64)); i++)
    {
        nn::Bit8* head = data + i;

        std::memcpy(head, testVector.data, testVector.dataSize);
        sha256.Initialize();
        sha256.Update(head, testVector.dataSize);
        sha256.GetHash(hash, nn::crypto::Sha256Generator::HashSize);

        EXPECT_ARRAY_EQ(hash, testVector.expectedHash, nn::crypto::Sha256Generator::HashSize);
    }

    // 非 32bit アラインメントの出力バッファに対して問題ないことをテストする
    for (int i = 0; i < static_cast<int>(sizeof(nn::Bit64)); i++)
    {
        nn::Bit8* head = hash + i;

        sha256.Initialize();
        sha256.Update(testVector.data, testVector.dataSize);
        sha256.GetHash(head, nn::crypto::Sha256Generator::HashSize);

        EXPECT_ARRAY_EQ(head, testVector.expectedHash, nn::crypto::Sha256Generator::HashSize);
    }
}

/**
  @brief   ハッシュ計算の中断と再開をテストします。
 */
TEST(Sha256Test, SuspendResume)
{
    // #8) 1000 bytes of 0x41 'A' (ヌル文字列ではない長めのメッセージ)
    Sha256TestVector testVector =
    {
        "A",
        1000,
        {0xc2, 0xe6, 0x86, 0x82, 0x34, 0x89, 0xce, 0xd2, 0x01, 0x7f, 0x60, 0x59, 0xb8, 0xb2, 0x39, 0x31,
         0x8b, 0x63, 0x64, 0xf6, 0xdc, 0xd8, 0x35, 0xd0, 0xa5, 0x19, 0x10, 0x5a, 0x1e, 0xad, 0xd6, 0xe4},
    };

    for (size_t suspendPoint = 64; suspendPoint < testVector.dataSize; suspendPoint += 64)
    {
        NN_LOG("Test suspend/resume: expected suspend size = %d byte\n", suspendPoint);

        nn::crypto::Sha256Context context;

        // 初期化して suspendPoint まで計算し、コンテキストを取得する
        {
            nn::crypto::Sha256Generator sha256;

            sha256.Initialize();

            size_t remaining = suspendPoint;
            while (remaining != 0 )
            {
                size_t calculationSize = std::min(std::strlen(testVector.data), remaining);
                calculationSize = (calculationSize == 0) ? 1: calculationSize; // ヌル文字列の場合は1を入力
                sha256.Update(testVector.data, calculationSize);
                remaining -= calculationSize;
            }

            // ブロックサイズの整数倍を入力しているのでここでは返り値は 0 であることが期待される
            EXPECT_TRUE(sha256.GetContext(&context) == 0);
        }

        // InitializeWithContext で初期化して計算を再開する
        {
            nn::crypto::Sha256Generator sha256;
            nn::Bit8                    hash[nn::crypto::Sha256Generator::HashSize];

            sha256.InitializeWithContext(&context);

            size_t remaining = testVector.dataSize - suspendPoint;
            while (remaining != 0 )
            {
                size_t calculationSize = std::min(std::strlen(testVector.data), remaining);
                calculationSize = (calculationSize == 0) ? 1: calculationSize; // ヌル文字列の場合は1を入力
                sha256.Update(testVector.data, calculationSize);
                remaining -= calculationSize;
            }

            sha256.GetHash(hash, nn::crypto::Sha256Generator::HashSize);

            EXPECT_ARRAY_EQ(hash, testVector.expectedHash, nn::crypto::Sha256Generator::HashSize);
        }
    }
}

/**
  @brief   ハッシュ計算の中断と再開をバッファリングがあった場合についてテストします。
 */
TEST(Sha256Test, SuspendResumeUnaligned)
{
    // #8) 1000 bytes of 0x41 'A' (ヌル文字列ではない長めのメッセージ)
    Sha256TestVector testVector =
    {
        "A",
        1000,
        {0xc2, 0xe6, 0x86, 0x82, 0x34, 0x89, 0xce, 0xd2, 0x01, 0x7f, 0x60, 0x59, 0xb8, 0xb2, 0x39, 0x31,
         0x8b, 0x63, 0x64, 0xf6, 0xdc, 0xd8, 0x35, 0xd0, 0xa5, 0x19, 0x10, 0x5a, 0x1e, 0xad, 0xd6, 0xe4},
    };

    for (size_t suspendPoint = 0; suspendPoint <= 128; suspendPoint++)
    {
        //NN_LOG("Test suspend/resume: expected suspend size = %d byte\n", suspendPoint);

        nn::crypto::Sha256Context context;
        uint8_t                   bufferedData[nn::crypto::Sha256Generator::BlockSize];
        size_t                    bufferedDataSize = 0;

        // 初期化して suspendPoint まで計算し、コンテキストを取得する
        {
            nn::crypto::Sha256Generator sha256;

            sha256.Initialize();

            size_t remaining = suspendPoint;
            while (remaining != 0 )
            {
                size_t calculationSize = std::min(std::strlen(testVector.data), remaining);
                calculationSize = (calculationSize == 0) ? 1: calculationSize; // ヌル文字列の場合は1を入力
                sha256.Update(testVector.data, calculationSize);
                remaining -= calculationSize;
            }

            size_t buffered = sha256.GetContext(&context);

            // バッファされているデータがある場合は別途保存する
            if (buffered != 0)
            {
                sha256.GetBufferedData(bufferedData, sizeof(bufferedData));
                bufferedDataSize = sha256.GetBufferedDataSize();
            }
        }

        // InitializeWithContext で初期化して計算を再開する
        {
            nn::crypto::Sha256Generator sha256;
            nn::Bit8                    hash[nn::crypto::Sha256Generator::HashSize];

            sha256.InitializeWithContext(&context);

            // バッファされているデータがあった場合は最初に入力する
            if (bufferedDataSize != 0)
            {
                sha256.Update(bufferedData, bufferedDataSize);
            }

            size_t remaining = testVector.dataSize - suspendPoint;
            while (remaining != 0 )
            {
                size_t calculationSize = std::min(std::strlen(testVector.data), remaining);
                calculationSize = (calculationSize == 0) ? 1: calculationSize; // ヌル文字列の場合は1を入力
                sha256.Update(testVector.data, calculationSize);
                remaining -= calculationSize;
            }

            sha256.GetHash(hash, nn::crypto::Sha256Generator::HashSize);

            EXPECT_ARRAY_EQ(hash, testVector.expectedHash, nn::crypto::Sha256Generator::HashSize);
        }
    }
}

/**
  @brief   デストラクタで内部データがクリアされることをテストします。
 */
TEST(Sha256Test, Destructor)
{
    nn::crypto::Sha256Generator sha;
    sha.Initialize();

    // 明示的にデストラクタを呼んで呼び出し前後でのメモリクリアを確認する
    EXPECT_ARRAY_NONZERO(&sha, sizeof(sha));
    sha.~Sha256Generator();
    EXPECT_ARRAY_ZERO(&sha, sizeof(sha));
}
