﻿/*--------------------------------------------------------------------------------*
  Copyright (C)Nintendo All rights reserved.

  These coded instructions, statements, and computer programs contain proprietary
  information of Nintendo and/or its licensed developers and are protected by
  national and international copyright laws. They may not be disclosed to third
  parties or copied or duplicated in any form, in whole or in part, without the
  prior written consent of Nintendo.

  The content herein is highly confidential and should be handled accordingly.
 *--------------------------------------------------------------------------------*/

#include <nn/nn_Common.h>
#include <nn/fssystem/fs_AesCtrStorage.h>
#include <nn/fssystem/fs_AsynchronousAccess.h>

#include <nnt/nntest.h>
#include <nnt/fsUtil/testFs_util.h>

#include "testFs_Unit_StorageLayerTestCase.h"

/**
* @brief    One parameter set for an instantiation of the AesCtrStorage tests.
*/
struct FormatTestParam
{
    int64_t baseStorageSize;                            //!< Size in bytes used to initialize the storage underneath AesCtrStorage.
    char key[nn::fssystem::AesCtrStorage::KeySize];     //!< Raw key bytes (filled with memcpy of exactly KeySize bytes).
    char iv[nn::fssystem::AesCtrStorage::IvSize];       //!< Raw IV bytes (filled with memcpy of exactly IvSize bytes).
    int blockSize;                                      //!< Granularity used by the falsification and PooledBuffer tests.
};

/**
* @brief    Builds the parameter sets used to instantiate FsAesCtrStorageTest.
*
* @details  Each entry is a distinct combination of storage size, key, IV and block size.
*           Every key / IV literal has 16 characters; NOTE(review): this assumes KeySize and
*           IvSize are at most 16 so memcpy never reads past the literal.
*
* @return   The list of parameter sets, in instantiation order.
*/
nnt::fs::util::Vector<FormatTestParam> MakeTestParam() NN_NOEXCEPT
{
    // Builds one parameter set; key and iv must supply at least KeySize / IvSize bytes.
    auto MakeParam = [](int64_t baseStorageSize, const char* key, const char* iv, int blockSize) NN_NOEXCEPT
    {
        FormatTestParam data;
        data.baseStorageSize = baseStorageSize;
        memcpy(data.key, key, nn::fssystem::AesCtrStorage::KeySize);
        memcpy(data.iv,  iv,  nn::fssystem::AesCtrStorage::IvSize);
        data.blockSize = blockSize;
        return data;
    };

    nnt::fs::util::Vector<FormatTestParam> params;
    params.push_back(MakeParam(4 * 1024 * 1024, "1234567890123456", "6543210987654321", 128 * 1024));
    params.push_back(MakeParam(512 * 1024,      "aaaabbbbccccdddd", "zzzzyyyyxxxxwwww", 16 * 1024));
    params.push_back(MakeParam(4 * 1024 * 1024, "abcdefghijklmnop", "klmnopqrstuvwxyz", 16 * 1024));
    // Key / IV made of symbol characters (including escaped backslash and double quote).
    params.push_back(MakeParam(4 * 1024 * 1024, "\\^-[@]:;/.,!\"#$%", "|~={`}*+_?><)('&", 16 * 1024));

    return params;
}

/**
* @brief    Builds the single parameter set used by FsAesCtrStorageLargeSizeTest.
*
* @details  The base storage is 64 GiB plus one block, so offsets well beyond 4 GB are reachable.
*/
nnt::fs::util::Vector<FormatTestParam> MakeLargeSizeTestParam() NN_NOEXCEPT
{
    const int BlockSize = 16 * 1024;
    const int64_t Size64GiB = static_cast<int64_t>(64) * 1024 * 1024 * 1024;

    FormatTestParam largeParam;
    largeParam.blockSize = BlockSize;
    largeParam.baseStorageSize = Size64GiB + largeParam.blockSize;
    memcpy(largeParam.key, "1234567890123456", nn::fssystem::AesCtrStorage::KeySize);
    memcpy(largeParam.iv,  "6543210987654321", nn::fssystem::AesCtrStorage::IvSize);

    nnt::fs::util::Vector<FormatTestParam> params;
    params.push_back(largeParam);
    return params;
}

/**
* @brief    Common fixture: puts an AesCtrStorage (wrapped by WriteSizeCheckStorage) on top of
*           the base storage supplied by the derived fixture.
*/
class FsAesCtrStorageTestBase : public ::testing::TestWithParam< FormatTestParam >
{
public:

    /**
    * @brief    Initializes the AesCtr storage stack according to the current parameter set.
    *
    * @details  WriteSizeCheckStorage receives the base storage, the values 1 and
    *           PooledBuffer::GetAllocatableSizeMax() (presumably the accepted write-size bounds —
    *           confirm against WriteSizeCheckStorage), and a factory that creates an
    *           AesCtrStorage over the storage it is handed, using the parameter's key and IV.
    *           Derived fixtures must have their base storage ready before calling this.
    */
    virtual void SetUp() NN_NOEXCEPT NN_OVERRIDE
    {
        const FormatTestParam& param = GetParam();

        m_pCheckStorage.reset(new nnt::fs::util::WriteSizeCheckStorage(
            GetBaseStorage(),
            1,
            nn::fssystem::PooledBuffer::GetAllocatableSizeMax(),
            // param is captured by value so the factory owns its own copy of key and IV.
            [param](nn::fs::IStorage* pBaseStorage) NN_NOEXCEPT -> std::unique_ptr<nn::fs::IStorage>
            {
                return std::unique_ptr<nn::fs::IStorage>(
                    new nn::fssystem::AesCtrStorage(
                        pBaseStorage,
                        param.key,
                        nn::fssystem::AesCtrStorage::KeySize,
                        param.iv,
                        nn::fssystem::AesCtrStorage::IvSize));
            }));
    }

    //! Returns the storage under test (owned by the fixture).
    nnt::fs::util::WriteSizeCheckStorage* GetStorage() NN_NOEXCEPT
    {
        return m_pCheckStorage.get();
    }

    //! Supplies the storage that sits underneath the AesCtrStorage.
    virtual nn::fs::IStorage* GetBaseStorage() NN_NOEXCEPT = 0;

private:
    std::unique_ptr<nnt::fs::util::WriteSizeCheckStorage> m_pCheckStorage;
};

/**
* @brief    Fixture whose base storage is an in-memory, access-counting storage.
*/
class FsAesCtrStorageTest : public FsAesCtrStorageTestBase
{
public:

    /**
    * @brief    Sizes the base storage from the parameter, then builds the AesCtr stack over it.
    */
    virtual void SetUp() NN_NOEXCEPT NN_OVERRIDE
    {
        // The base storage has to exist before the base fixture wraps it.
        m_BaseStorage.Initialize(GetParam().baseStorageSize);

        FsAesCtrStorageTestBase::SetUp();
    }

    //! Base storage as a plain IStorage (also used by tests that bypass encryption).
    virtual nn::fs::IStorage* GetBaseStorage() NN_NOEXCEPT NN_OVERRIDE
    {
        return GetBaseMemoryStorage();
    }

    //! Base storage with its access-counting interface exposed.
    nnt::fs::util::AccessCountedMemoryStorage* GetBaseMemoryStorage() NN_NOEXCEPT
    {
        return &m_BaseStorage;
    }

private:
    nnt::fs::util::AccessCountedMemoryStorage m_BaseStorage;
};

/**
* @brief    Fixture whose base storage is a VirtualMemoryStorage, used for sizes above 4 GB.
*/
class FsAesCtrStorageLargeSizeTest : public FsAesCtrStorageTestBase
{
public:

    /**
    * @brief    Initializes the large base storage, then builds the AesCtr stack over it.
    */
    virtual void SetUp() NN_NOEXCEPT NN_OVERRIDE
    {
        const int64_t storageSize = GetParam().baseStorageSize;
        m_Storage.Initialize(storageSize);

        FsAesCtrStorageTestBase::SetUp();
    }

    /**
    * @brief    Releases the large base storage.
    *
    * @note     NOTE(review): the storage stack built in SetUp still holds a pointer to m_Storage
    *           until the fixture is destroyed; confirm nothing touches it after Finalize().
    */
    virtual void TearDown() NN_NOEXCEPT NN_OVERRIDE
    {
        m_Storage.Finalize();
    }

    //! Supplies the large base storage to the common fixture.
    virtual nn::fs::IStorage* GetBaseStorage() NN_NOEXCEPT NN_OVERRIDE
    {
        return &m_Storage;
    }

private:
    nnt::fs::util::VirtualMemoryStorage m_Storage;
};

// Writes baseStorageSize bytes of random data at offset 0 and reads them back in one request.
TEST_P(FsAesCtrStorageTest, TestReadWrite)
{
    const size_t totalSize = static_cast<size_t>(GetParam().baseStorageSize);

    std::unique_ptr<char[]> expected(new char[totalSize]);
    std::unique_ptr<char[]> actual(new char[totalSize]);
    nnt::fs::util::FillBufferWithRandomValue(expected.get(), totalSize);

    nnt::fs::util::WriteSizeCheckStorage* pStorage = GetStorage();
    NNT_ASSERT_RESULT_SUCCESS(pStorage->Write(0, expected.get(), totalSize));
    NNT_ASSERT_RESULT_SUCCESS(pStorage->Read(0, actual.get(), totalSize));

    // Round trip through encryption must be lossless.
    ASSERT_EQ(0, memcmp(expected.get(), actual.get(), totalSize));
}

// Checks that a cache-invalidate request issued on the AesCtr stack reaches the base storage.
TEST_P(FsAesCtrStorageTest, TestInvalidate)
{
    nnt::fs::util::AccessCountedMemoryStorage* pBase = GetBaseMemoryStorage();
    pBase->ResetAccessCounter();

    const int64_t rangeSize = GetParam().baseStorageSize;
    NNT_EXPECT_RESULT_SUCCESS(
        GetStorage()->OperateRange(nn::fs::OperationId::Invalidate, 0, rangeSize));

    // At least one invalidate must have been forwarded downwards.
    EXPECT_GT(pBase->GetInvalidateTimes(), 0);
}

// Runs write/read/verify loops concurrently from several threads. Each thread owns a disjoint
// slice of the storage (slice i starts at i * (baseStorageSize / ThreadCount)).
TEST_P(FsAesCtrStorageTest, TestMultiThreadReadWrite)
{
    static const size_t StackSize = 32 * 1024;
    static const int ThreadCount = 32;
    NN_OS_ALIGNAS_THREAD_STACK static char s_ThreadStack[StackSize * ThreadCount] = {};

    const size_t baseStorageSize = static_cast<size_t>(GetParam().baseStorageSize);

    // Write-size checking is turned off for the concurrent run.
    // NOTE(review): the reason is not visible here — confirm against WriteSizeCheckStorage.
    GetStorage()->SetChecked(false);

    struct ThreadArgument
    {
        int id;                     // Index of the slice owned by the thread.
        nn::fs::IStorage* pStorage; // Storage under test, shared by all threads.
        size_t size;                // Slice size in bytes.
    };

    static nn::os::ThreadType threads[ThreadCount] = {};
    static ThreadArgument argument[ThreadCount] = {};

    // Repeatedly writes random data to the thread's slice and verifies it reads back unchanged.
    auto ThreadFunction = [](void* ptr) NN_NOEXCEPT
    {
        ThreadArgument* pArgument = static_cast<ThreadArgument*>(ptr);
        const int64_t offset = static_cast<int64_t>(pArgument->id) * static_cast<int64_t>(pArgument->size);

        std::unique_ptr<char[]> writeBuffer(new char[pArgument->size]);
        std::unique_ptr<char[]> readBuffer(new char[pArgument->size]);

        const int LoopCount = 30;
        for( int i = 0; i < LoopCount; ++i )
        {
            nnt::fs::util::FillBufferWithRandomValue(writeBuffer.get(), pArgument->size);
            NNT_ASSERT_RESULT_SUCCESS(
                pArgument->pStorage->Write(offset, writeBuffer.get(), pArgument->size)
            );
            NNT_ASSERT_RESULT_SUCCESS(
                pArgument->pStorage->Read(offset, readBuffer.get(), pArgument->size)
            );

            // The data read back must match what was written.
            ASSERT_EQ(0, memcmp(writeBuffer.get(), readBuffer.get(), pArgument->size));
        }
    };

    for( int i = 0; i < ThreadCount; ++i )
    {
        argument[i].id = i;
        argument[i].pStorage = GetStorage();
        argument[i].size = baseStorageSize / ThreadCount;

        // NOTE(review): thread i runs at DefaultThreadPriority - i; confirm this stays within the
        // valid priority range for ThreadCount threads.
        NNT_ASSERT_RESULT_SUCCESS(
            nn::os::CreateThread(
                &threads[i],
                ThreadFunction,
                &argument[i],
                s_ThreadStack + StackSize * i,
                StackSize,
                nn::os::DefaultThreadPriority - i
            )
        );
    }

    // Start all threads.
    for( auto& thread : threads )
    {
        nn::os::StartThread(&thread);
    }

    // Wait for completion and release the threads.
    for( auto& thread : threads )
    {
        nn::os::WaitThread(&thread);
        nn::os::DestroyThread(&thread);
    }
}

// Tampers with the encrypted data in the base storage and checks that, on read-back through
// AesCtrStorage, only the tampered bytes change (AES-CTR has no cross-byte error propagation).
TEST_P(FsAesCtrStorageTest, TestFalsification)
{
    const FormatTestParam& param = GetParam();
    const size_t baseStorageSize = static_cast<size_t>(param.baseStorageSize);
    const size_t blockSize = static_cast<size_t>(param.blockSize);

    std::unique_ptr<char[]> writeBuffer(new char[baseStorageSize]);
    std::unique_ptr<char[]> readBuffer(new char[baseStorageSize]);

    nnt::fs::util::FillBufferWithRandomValue(writeBuffer.get(), baseStorageSize);

    NNT_ASSERT_RESULT_SUCCESS(
        GetStorage()->Write(0, writeBuffer.get(), baseStorageSize)
    );

    // Pick a random block, then a random FalsificationSize-byte span lying entirely inside it.
    static const size_t FalsificationSize = 8;
    std::mt19937 mt(nnt::fs::util::GetRandomSeed());
    const int falsificationBlock
        = std::uniform_int_distribution<>(
              0,
              static_cast<int>(baseStorageSize / blockSize - 1)
          )(mt);
    const int writeOffset
        = std::uniform_int_distribution<>(
              0,
              static_cast<int>(blockSize - FalsificationSize - 1)
          )(mt) + falsificationBlock * param.blockSize;

    // Tamper with the stored ciphertext: read the span straight from the base storage,
    // increment every byte and write it back, bypassing AesCtrStorage. Deriving the new bytes
    // from the stored data guarantees every byte of the span actually changes.
    char falsifiedData[FalsificationSize];
    NNT_ASSERT_RESULT_SUCCESS(
        GetBaseStorage()->Read(writeOffset, falsifiedData, FalsificationSize)
    );
    for( size_t i = 0; i < FalsificationSize; ++i )
    {
        ++falsifiedData[i];
    }
    NNT_ASSERT_RESULT_SUCCESS(
        GetBaseStorage()->Write(writeOffset, falsifiedData, FalsificationSize)
    );

    // Asserts that actual equals expected over [0, size) except for the FalsificationSize bytes
    // starting at falsifiedOffset, which must differ.
    auto VerifyOnlyFalsifiedRangeDiffers =
        [](const char* expected, const char* actual, size_t size, size_t falsifiedOffset) NN_NOEXCEPT
        {
            const size_t falsifiedEnd = falsifiedOffset + FalsificationSize;

            // Before the tampered span.
            ASSERT_EQ(0, memcmp(expected, actual, falsifiedOffset));
            // After the tampered span.
            ASSERT_EQ(0, memcmp(expected + falsifiedEnd, actual + falsifiedEnd, size - falsifiedEnd));
            // The tampered span itself.
            ASSERT_NE(0, memcmp(expected + falsifiedOffset, actual + falsifiedOffset, FalsificationSize));
        };

    // Whole-storage read: everything except the tampered span must decrypt correctly.
    NNT_ASSERT_RESULT_SUCCESS(
        GetStorage()->Read(0, readBuffer.get(), baseStorageSize)
    );
    ASSERT_NO_FATAL_FAILURE(VerifyOnlyFalsifiedRangeDiffers(
        writeBuffer.get(), readBuffer.get(), baseStorageSize, static_cast<size_t>(writeOffset)));

    // Partial reads around and over the tampered block.
    const size_t falsifiedBlockOffset = blockSize * static_cast<size_t>(falsificationBlock);
    const size_t nextBlockOffset = falsifiedBlockOffset + blockSize;

    // Blocks before the tampered block.
    NNT_ASSERT_RESULT_SUCCESS(
        GetStorage()->Read(0, readBuffer.get(), falsifiedBlockOffset)
    );
    ASSERT_EQ(0, memcmp(writeBuffer.get(), readBuffer.get(), falsifiedBlockOffset));

    // Blocks after the tampered block.
    NNT_ASSERT_RESULT_SUCCESS(
        GetStorage()->Read(
            static_cast<int64_t>(nextBlockOffset), readBuffer.get(), baseStorageSize - nextBlockOffset)
    );
    ASSERT_EQ(
        0,
        memcmp(writeBuffer.get() + nextBlockOffset, readBuffer.get(), baseStorageSize - nextBlockOffset)
    );

    // The tampered block on its own.
    NNT_ASSERT_RESULT_SUCCESS(
        GetStorage()->Read(static_cast<int64_t>(falsifiedBlockOffset), readBuffer.get(), blockSize)
    );
    ASSERT_NO_FATAL_FAILURE(VerifyOnlyFalsifiedRangeDiffers(
        writeBuffer.get() + falsifiedBlockOffset,
        readBuffer.get(),
        blockSize,
        static_cast<size_t>(writeOffset) - falsifiedBlockOffset));
}// NOLINT(impl/function_size)

// Checks that data residing in a PooledBuffer can be written through AesCtrStorage and read back.
TEST_P(FsAesCtrStorageTest, TestWritePooledBuffer)
{
    const size_t WriteDataSize = GetParam().blockSize * 4;

    std::unique_ptr<char[]> expected(new char[WriteDataSize]);
    std::unique_ptr<char[]> actual(new char[WriteDataSize]);
    nnt::fs::util::FillBufferWithRandomValue(expected.get(), WriteDataSize);

    // Stage the source data in a PooledBuffer and issue the write from there.
    nn::fssystem::PooledBuffer pooledBuffer;
    pooledBuffer.Allocate(WriteDataSize, WriteDataSize);
    memcpy(pooledBuffer.GetBuffer(), expected.get(), WriteDataSize);

    nnt::fs::util::WriteSizeCheckStorage* pStorage = GetStorage();
    NNT_ASSERT_RESULT_SUCCESS(pStorage->Write(0, pooledBuffer.GetBuffer(), WriteDataSize));
    NNT_ASSERT_RESULT_SUCCESS(pStorage->Read(0, actual.get(), WriteDataSize));

    ASSERT_EQ(0, memcmp(expected.get(), actual.get(), WriteDataSize));
}

// Runs write/read/verify loops concurrently from several threads, writing from a per-thread
// PooledBuffer. Each thread owns a disjoint slice of the storage
// (slice i starts at i * (baseStorageSize / ThreadCount)).
TEST_P(FsAesCtrStorageTest, TestMultiThreadReadWriteWithPooledBuffer)
{
    static const size_t StackSize = 32 * 1024;
    static const int ThreadCount = 32;
    NN_OS_ALIGNAS_THREAD_STACK static char s_ThreadStack[StackSize * ThreadCount] = {};

    const size_t baseStorageSize = static_cast<size_t>(GetParam().baseStorageSize);

    // Write-size checking is turned off for the concurrent run.
    // NOTE(review): the reason is not visible here — confirm against WriteSizeCheckStorage.
    GetStorage()->SetChecked(false);

    struct ThreadArgument
    {
        int id;                     // Index of the slice owned by the thread.
        nn::fs::IStorage* pStorage; // Storage under test, shared by all threads.
        size_t size;                // Slice size in bytes.
    };

    static nn::os::ThreadType threads[ThreadCount] = {};
    static ThreadArgument argument[ThreadCount] = {};

    // Repeatedly writes random data (via a PooledBuffer) to the thread's slice and verifies it
    // reads back unchanged.
    auto ThreadFunction = [](void* ptr) NN_NOEXCEPT
    {
        ThreadArgument* pArgument = static_cast<ThreadArgument*>(ptr);
        const int64_t offset = static_cast<int64_t>(pArgument->id) * static_cast<int64_t>(pArgument->size);

        nn::fssystem::PooledBuffer poolBuffer;
        poolBuffer.Allocate(pArgument->size, pArgument->size);

        std::unique_ptr<char[]> writeBuffer(new char[pArgument->size]);
        std::unique_ptr<char[]> readBuffer(new char[pArgument->size]);

        const int LoopCount = 30;
        for( int i = 0; i < LoopCount; ++i )
        {
            nnt::fs::util::FillBufferWithRandomValue(writeBuffer.get(), pArgument->size);

            // Copy writeBuffer into poolBuffer and write to the storage from poolBuffer.
            memcpy(poolBuffer.GetBuffer(), writeBuffer.get(), pArgument->size);
            NNT_ASSERT_RESULT_SUCCESS(
                pArgument->pStorage->Write(offset, poolBuffer.GetBuffer(), pArgument->size)
            );
            NNT_ASSERT_RESULT_SUCCESS(
                pArgument->pStorage->Read(offset, readBuffer.get(), pArgument->size)
            );

            // The data read back must match what was written.
            ASSERT_EQ(0, memcmp(writeBuffer.get(), readBuffer.get(), pArgument->size));
        }
    };

    for( int i = 0; i < ThreadCount; ++i )
    {
        argument[i].id = i;
        argument[i].pStorage = GetStorage();
        argument[i].size = baseStorageSize / ThreadCount;

        // NOTE(review): thread i runs at DefaultThreadPriority - i; confirm this stays within the
        // valid priority range for ThreadCount threads.
        NNT_ASSERT_RESULT_SUCCESS(
            nn::os::CreateThread(
                &threads[i],
                ThreadFunction,
                &argument[i],
                s_ThreadStack + StackSize * i,
                StackSize,
                nn::os::DefaultThreadPriority - i
            )
        );
    }

    // Start all threads.
    for( auto& thread : threads )
    {
        nn::os::StartThread(&thread);
    }

    // Wait for completion and release the threads.
    for( auto& thread : threads )
    {
        nn::os::WaitThread(&thread);
        nn::os::DestroyThread(&thread);
    }
}

// Checks QueryRange argument validation and the key-type / speed-emulation flags it reports.
TEST_P(FsAesCtrStorageTest, QueryRange)
{
    const int64_t rangeSize = GetParam().baseStorageSize;
    nnt::fs::util::WriteSizeCheckStorage* pStorage = GetStorage();
    nn::fs::QueryRangeInfo info;

    // A null output buffer and a zero output size must both be rejected.
    NNT_ASSERT_RESULT_FAILURE(nn::fs::ResultNullptrArgument,
        pStorage->OperateRange(nullptr, sizeof(info), nn::fs::OperationId::QueryRange, 0, rangeSize, nullptr, 0));
    NNT_ASSERT_RESULT_FAILURE(nn::fs::ResultInvalidSize,
        pStorage->OperateRange(&info, 0, nn::fs::OperationId::QueryRange, 0, rangeSize, nullptr, 0));

    // A well-formed query reports the software-AES internal key type and no speed emulation.
    NNT_ASSERT_RESULT_SUCCESS(
        pStorage->OperateRange(&info, sizeof(info), nn::fs::OperationId::QueryRange, 0, rangeSize, nullptr, 0));
    EXPECT_EQ(static_cast<int32_t>(nn::fs::AesCtrKeyTypeFlag::InternalKeyForSoftwareAes), info.aesCtrKeyTypeFlag);
    EXPECT_EQ(0, info.speedEmulationTypeFlag);
}

// Reads and writes at offsets beyond 4 GB.
// The work is delegated to TestWriteReadStorageWithLargeOffset, which is declared outside this
// file; NOTE(review): presumably the second argument is the access size — confirm in that helper.
TEST_P(FsAesCtrStorageLargeSizeTest, TestReadWrite)
{
    const int accessSize = GetParam().blockSize;
    TestWriteReadStorageWithLargeOffset(GetStorage(), accessSize);
}

// Boundary test for writes at offsets beyond 4 GB whose source data lives in a PooledBuffer.
TEST_P(FsAesCtrStorageLargeSizeTest, TestWritePooledBuffer)
{
    nn::fssystem::PooledBuffer pooledBuffer;
    const size_t bufferSize = GetParam().blockSize;
    pooledBuffer.Allocate(bufferSize, bufferSize);

    // The helper performs the large-offset write/read using the pooled buffer as the data buffer.
    TestWriteReadStorageWithLargeOffset(GetStorage(), bufferSize, pooledBuffer.GetBuffer());
}

// Checks QueryRange over a range larger than 4 GB: argument validation and reported flags.
TEST_P(FsAesCtrStorageLargeSizeTest, QueryRange)
{
    const int64_t rangeSize = GetParam().baseStorageSize;
    nnt::fs::util::WriteSizeCheckStorage* pStorage = GetStorage();
    nn::fs::QueryRangeInfo info;

    // A null output buffer and a zero output size must both be rejected.
    NNT_ASSERT_RESULT_FAILURE(nn::fs::ResultNullptrArgument,
        pStorage->OperateRange(nullptr, sizeof(info), nn::fs::OperationId::QueryRange, 0, rangeSize, nullptr, 0));
    NNT_ASSERT_RESULT_FAILURE(nn::fs::ResultInvalidSize,
        pStorage->OperateRange(&info, 0, nn::fs::OperationId::QueryRange, 0, rangeSize, nullptr, 0));

    // A well-formed query reports the software-AES internal key type and no speed emulation.
    NNT_ASSERT_RESULT_SUCCESS(
        pStorage->OperateRange(&info, sizeof(info), nn::fs::OperationId::QueryRange, 0, rangeSize, nullptr, 0));
    EXPECT_EQ(static_cast<int32_t>(nn::fs::AesCtrKeyTypeFlag::InternalKeyForSoftwareAes), info.aesCtrKeyTypeFlag);
    EXPECT_EQ(0, info.speedEmulationTypeFlag);
}

// Runs every FsAesCtrStorageTest case once per parameter set returned by MakeTestParam().
INSTANTIATE_TEST_CASE_P(
    AesCtrStorageTest,
    FsAesCtrStorageTest,
    ::testing::ValuesIn(MakeTestParam())
);

// Runs every FsAesCtrStorageLargeSizeTest case once per parameter set returned by
// MakeLargeSizeTestParam().
INSTANTIATE_TEST_CASE_P(
    AesCtrStorageBoundaryTest,
    FsAesCtrStorageLargeSizeTest,
    ::testing::ValuesIn(MakeLargeSizeTestParam())
);


