﻿/*--------------------------------------------------------------------------------*
  Copyright (C)Nintendo All rights reserved.

  These coded instructions, statements, and computer programs contain proprietary
  information of Nintendo and/or its licensed developers and are protected by
  national and international copyright laws. They may not be disclosed to third
  parties or copied or duplicated in any form, in whole or in part, without the
  prior written consent of Nintendo.

  The content herein is highly confidential and should be handled accordingly.
 *--------------------------------------------------------------------------------*/

#include <nn/nn_Common.h>
#include <nn/nn_Result.h>
#include <nn/fs.h>
#include <nn/result/result_HandlingUtility.h>
#include <nn/nn_Log.h>
#include <nn/nn_Assert.h>

#include <nnt/nntest.h>
#include <nnt/base/testBase_Exit.h>
#include <nnt/fsUtil/testFs_util.h>
#include <nnt/result/testResult_Assert.h>
#include <nnt/nnt_Argument.h>

#include <nn/fssystem/fs_AllocatorUtility.h>
#include "detail/fssrv_SaveDataTransferStream.h"
#include <nn/util/util_Optional.h>
#include <nn/crypto/crypto_Sha256Generator.h>

using namespace nn;
using namespace nn::fs;
using namespace nn::fs::detail;
using namespace nnt::fs::util;

using namespace nn::fssystem;
using namespace nn::fssrv::detail;

namespace {

// IStream implementation for tests, backed by a buffer obtained from AllocateBuffer().
// Every stream operation is delegated to a StorageStream that is built on a
// MemoryStorage created over that buffer.
class MemoryStream : public IStream
{
public:
    // Allocates bufferSize bytes and stacks MemoryStorage / StorageStream on top of them.
    explicit MemoryStream(size_t bufferSize)
        : m_BufferSize(bufferSize),
          m_Buffer(AllocateBuffer(bufferSize)),
          m_MemoryStorage(AllocateShared<MemoryStorage>(m_Buffer.get(), bufferSize)),
          m_StorageStream(m_MemoryStorage)
    {
    }

    // Size in bytes that was passed to the constructor.
    size_t GetBufferSize() { return m_BufferSize; }

    // Start of the underlying buffer.
    char* GetBuffer() { return m_Buffer.get(); }

    // IStream overrides: each one simply forwards to the internal StorageStream.

    virtual Result Push(const char* buffer, size_t size) NN_NOEXCEPT NN_OVERRIDE
    {
        return m_StorageStream.Push(buffer, size);
    }

    virtual Result Pull(size_t* outValue, char* buffer, size_t size) NN_NOEXCEPT NN_OVERRIDE
    {
        return m_StorageStream.Pull(outValue, buffer, size);
    }

    virtual Result GetRestRawDataSize(int64_t* outValue) NN_NOEXCEPT NN_OVERRIDE
    {
        return m_StorageStream.GetRestRawDataSize(outValue);
    }

    virtual bool IsEnd() NN_NOEXCEPT NN_OVERRIDE
    {
        return m_StorageStream.IsEnd();
    }

private:
    // The declaration order below is also the initialization order used by the
    // constructor (each member is built from the previous ones), so keep it as is.
    size_t m_BufferSize;
    decltype(AllocateBuffer(0)) m_Buffer;
    std::shared_ptr<MemoryStorage> m_MemoryStorage;
    StorageStream m_StorageStream;
};

// Sink wrapper that also exposes accessors to the buffer at the bottom of the sink chain.
//
// Push() is forwarded unchanged to the wrapped sink. GetBuffer() / GetBufferSize()
// return the buffer of the MemoryStream handed to the constructor, so a test can
// inspect what finally reached that stream.
class MockSink : public ISink
{
public:
    // sink         : the sink every Push() is forwarded to.
    // memoryStream : the stream whose buffer is exposed through the accessors.
    MockSink(std::shared_ptr<ISink> sink, std::shared_ptr<MemoryStream> memoryStream)
        : m_MemoryStream(std::move(memoryStream))
        , m_BaseSink(std::move(sink))
        // Read through the member: the constructor argument has already been moved from.
        , m_Buffer(m_MemoryStream->GetBuffer())
        , m_BufferSize(m_MemoryStream->GetBufferSize())
    {
    }

    // Start of the exposed MemoryStream buffer.
    char* GetBuffer()
    {
        return m_Buffer;
    }

    // Size in bytes of the exposed MemoryStream buffer.
    size_t GetBufferSize()
    {
        return m_BufferSize;
    }

public:
    virtual Result Push(const char* buffer, size_t size) NN_NOEXCEPT NN_OVERRIDE
    {
        return m_BaseSink->Push(buffer, size);
    }


private:
    // Keeps the stream alive for as long as m_Buffer may be handed out, instead of
    // relying on the wrapped sink to hold a reference to it. Declared first so that
    // it is destroyed last, i.e. after m_BaseSink.
    std::shared_ptr<MemoryStream> m_MemoryStream;
    std::shared_ptr<ISink> m_BaseSink;
    char* m_Buffer;
    size_t m_BufferSize;
};


// TODO:
// Add a test verifying that the key and the IV are actually reflected (take effect).

// AesGcmSource::IEncryptor for tests, implemented with crypto::Aes128GcmEncryptor.
// It uses a fixed IV and a fixed key whose first byte is the key generation.
class Encryptor : public AesGcmSource::IEncryptor, public fs::detail::Newable
{
public:
    Encryptor() NN_NOEXCEPT
    {
    }

public:
    // Starts encryption with the fixed test key / IV and writes the header that
    // describes them (IV and key generation) to *pOutHeader.
    virtual void   Initialize(AesGcmStreamHeader* pOutHeader) NN_NOEXCEPT NN_OVERRIDE
    {
        // Value-initialize the header: it is copied out as a whole below, so any
        // member that is not set explicitly here must not hold indeterminate bytes.
        AesGcmStreamHeader header = {};
        memset(header.iv, 0, sizeof(header.iv));
        header.iv[0] = 0x0A;
        header.keyGeneration = 1;

        // Test key: every byte is 0x0B except the first, which holds the key generation.
        char key[decltype(m_Encryptor)::BlockSize];
        memset(key, 0x0B, sizeof(key));
        key[0] = static_cast<char>(header.keyGeneration);

        m_Encryptor.Initialize(key, sizeof(key), header.iv, sizeof(header.iv));
        memcpy(pOutHeader, &header, sizeof(AesGcmStreamHeader));
    }

    // Forwards to crypto::Aes128GcmEncryptor::Update() and returns its result.
    virtual size_t Update(void* pDst, size_t dstSize, const void* pSrc, size_t srcSize) NN_NOEXCEPT NN_OVERRIDE
    {
        return m_Encryptor.Update(pDst, dstSize, pSrc, srcSize);
    }

    // Forwards to crypto::Aes128GcmEncryptor::GetMac().
    virtual void   GetMac(void* pMac, size_t macSize) NN_NOEXCEPT NN_OVERRIDE
    {
        m_Encryptor.GetMac(pMac, macSize);
    }

private:

    crypto::Aes128GcmEncryptor m_Encryptor;
};

// AesGcmSink::IDecryptor for tests, implemented with crypto::Aes128GcmDecryptor.
class Decryptor : public AesGcmSink::IDecryptor, public fs::detail::Newable
{
public:
    Decryptor() NN_NOEXCEPT {}

    // Builds the test key from header.keyGeneration and starts decryption with header.iv.
    virtual void Initialize(const AesGcmStreamHeader& header) NN_NOEXCEPT NN_OVERRIDE
    {
        // Test key: every byte is 0x0B except the first, which holds the key generation.
        char testKey[decltype(m_Decryptor)::BlockSize];
        memset(testKey, 0x0B, sizeof(testKey));
        testKey[0] = static_cast<char>(header.keyGeneration);

        m_Decryptor.Initialize(testKey, sizeof(testKey), header.iv, sizeof(header.iv));
    }

    // Forwards to crypto::Aes128GcmDecryptor::GetMac().
    virtual void GetMac(void* pMac, size_t macSize) NN_NOEXCEPT NN_OVERRIDE
    {
        m_Decryptor.GetMac(pMac, macSize);
    }

    // Forwards to crypto::Aes128GcmDecryptor::Update() and returns its result.
    virtual size_t Update(void* pDst, size_t dstSize, const void* pSrc, size_t srcSize) NN_NOEXCEPT NN_OVERRIDE
    {
        return m_Decryptor.Update(pDst, dstSize, pSrc, srcSize);
    }

private:
    crypto::Aes128GcmDecryptor m_Decryptor;
};


// Checks that pulling / pushing with various buffer sizes does not change the result.
//
// Size         : size in bytes passed to both factory functions.
// CreateSource : callable taking a size and returning a pointer-like object that offers
//                Pull(), IsEnd() and GetRestRawDataSize().
// CreateSink   : callable taking a size and returning a pointer-like object that offers
//                Push(), GetBuffer() and GetBufferSize().
//
// Pull phase: for each buffer size a fresh source is drained and the SHA-256 of all
// pulled bytes must equal the hash obtained with the first buffer size.
// Push phase: the stream pulled in one go is pushed into a fresh sink in chunks of each
// buffer size and the SHA-256 of the sink's buffer must equal the first one as well.
template <int Size, typename CreateSourceFunction, typename CreateSinkFunction>
void PullAndPushByVariousSize(CreateSourceFunction CreateSource, CreateSinkFunction CreateSink)
{
    struct Hash
    {
        char value[32];
    };

    // Tiny sizes, sizes around the tail / header sizes, sizes around header + Size and
    // header + Size + tail, and finally 1 MiB.
    const size_t BufferSizeArray[] =
    {
        1,
        2,
        3,
        sizeof(AesGcmStreamTail) - 1,
        sizeof(AesGcmStreamTail),
        sizeof(AesGcmStreamTail) + 1,

        sizeof(AesGcmStreamHeader) - 1,
        sizeof(AesGcmStreamHeader),
        sizeof(AesGcmStreamHeader) + 1,

        sizeof(AesGcmStreamHeader) + Size - 1,
        sizeof(AesGcmStreamHeader) + Size,
        sizeof(AesGcmStreamHeader) + Size + 1,

        sizeof(AesGcmStreamHeader) + Size + sizeof(AesGcmStreamTail) - 1,
        sizeof(AesGcmStreamHeader) + Size + sizeof(AesGcmStreamTail),
        sizeof(AesGcmStreamHeader) + Size + sizeof(AesGcmStreamTail) + 1,

        1024 * 1024,
    };

    // pull
    {
        util::optional<Hash> masterHashPulled;
        for (auto BufferSize : BufferSizeArray)
        {
            // BufferSize is a size_t; convert it explicitly so the argument matches "%d".
            NN_LOG("BufferSize: %d\n", static_cast<int>(BufferSize));

            auto pSource = CreateSource(Size);
            crypto::Sha256Generator sha;
            sha.Initialize();
            auto buffer = AllocateBuffer(BufferSize);
            ASSERT_NE(nullptr, buffer.get());

            // Out-values are zero-initialized: the result checks below are non-fatal, so
            // after a failed call the values are still read and must not be indeterminate.
            int64_t lastRestSize = 0;
            NNT_EXPECT_RESULT_SUCCESS(pSource->GetRestRawDataSize(&lastRestSize));
            EXPECT_GT(lastRestSize, 0);

            while (true)
            {
                // Starting from 0 also guarantees that a failing Pull() ends the loop
                // through the "finished" branch instead of spinning on a garbage size.
                size_t pulledSize = 0;
                NNT_EXPECT_RESULT_SUCCESS(pSource->Pull(&pulledSize, buffer.get(), BufferSize));
                sha.Update(buffer.get(), pulledSize);
                // DumpBuffer(buffer.get(), pulledSize);

                // The remaining size must never increase.
                int64_t restSize = 0;
                NNT_EXPECT_RESULT_SUCCESS(pSource->GetRestRawDataSize(&restSize));
                EXPECT_LE(restSize, lastRestSize);
                lastRestSize = restSize;

                if (pulledSize < BufferSize) // finished
                {
                    EXPECT_TRUE(pSource->IsEnd());
                    EXPECT_EQ(0, lastRestSize);

                    Hash hash;
                    sha.GetHash(&hash.value, sizeof(hash.value));

                    if (masterHashPulled == util::nullopt)
                    {
                        // First buffer size: remember its hash as the reference.
                        masterHashPulled.emplace();
                        memcpy(masterHashPulled->value, hash.value, sizeof(hash.value));
                    }
                    else
                    {
                        // Must match what was pulled with the other buffer sizes.
                        NNT_FS_UTIL_EXPECT_MEMCMPEQ(masterHashPulled->value, hash.value, sizeof(hash.value));
                    }

                    break;
                }
                else
                {
                    if (pSource->IsEnd())
                    {
                        // The buffer was filled completely and the source already reports
                        // the end: a further pull must then yield no data.
                        size_t unexpectedPulledSize = 0;
                        NNT_EXPECT_RESULT_SUCCESS(pSource->Pull(&unexpectedPulledSize, buffer.get(), BufferSize));
                        EXPECT_EQ(static_cast<size_t>(0), unexpectedPulledSize);
                    }
                }
            }
        }
    }

    // Pull the whole stream once; it is the input of the push phase. A single Pull() with
    // this buffer is expected to drain the source, which the IsEnd() check verifies.
    const size_t PullBufferSize = 1024 * 1024;
    auto pulledBuffer = AllocateBuffer(PullBufferSize);
    ASSERT_NE(nullptr, pulledBuffer.get());

    size_t pulledSize = 0;
    {
        auto pSource = CreateSource(Size);
        NNT_EXPECT_RESULT_SUCCESS(pSource->Pull(&pulledSize, pulledBuffer.get(), PullBufferSize));
        EXPECT_TRUE(pSource->IsEnd());
    }

    // push
    {
        util::optional<Hash> masterHashPushed;

        for (auto BufferSize : BufferSizeArray)
        {
            // BufferSize is a size_t; convert it explicitly so the argument matches "%d".
            NN_LOG("BufferSize: %d\n", static_cast<int>(BufferSize));

            auto pSink = CreateSink(Size);

            // Feed the pulled stream to the sink in chunks of at most BufferSize bytes.
            size_t offset = 0;
            while (offset < pulledSize)
            {
                const size_t pushSize = std::min(pulledSize - offset, BufferSize);
                NNT_EXPECT_RESULT_SUCCESS(pSink->Push(pulledBuffer.get() + offset, pushSize));

                offset += pushSize;
            }

            Hash hash;
            crypto::GenerateSha256Hash(&hash.value, sizeof(hash.value), pSink->GetBuffer(), pSink->GetBufferSize());

            if (masterHashPushed == util::nullopt)
            {
                // First buffer size: remember its hash as the reference.
                masterHashPushed.emplace();
                memcpy(masterHashPushed->value, hash.value, sizeof(hash.value));
            }
            else
            {
                // Must match what was pushed with the other buffer sizes.
                NNT_FS_UTIL_EXPECT_MEMCMPEQ(masterHashPushed->value, hash.value, sizeof(hash.value));
            }
        }
    }

} // NOLINT(impl/function_size)

// Runs the buffer-size sweep of PullAndPushByVariousSize() on the compression streams.
TEST(CompressionStream, PullAndPushByVariousSize)
{
    // Creates an initialized CompressionSource on top of a MemoryStream of `size` bytes
    // whose buffer is filled by FillBufferWith32BitCount().
    const auto makeSource = [](size_t size)
    {
        auto rawStream = AllocateShared<MemoryStream>(size);
        auto compressionSource = AllocateShared<CompressionSource>(rawStream);
        NNT_EXPECT_RESULT_SUCCESS(compressionSource->Initialize());
        FillBufferWith32BitCount(rawStream->GetBuffer(), size, 0);
        return compressionSource;
    };

    // Creates an initialized DecompressionSink on top of a MemoryStream of `size` bytes,
    // wrapped in a MockSink that gives access to that stream's buffer.
    const auto makeSink = [](size_t size)
    {
        auto rawStream = AllocateShared<MemoryStream>(size);
        auto decompressionSink = AllocateShared<DecompressionSink>(rawStream);
        NNT_EXPECT_RESULT_SUCCESS(decompressionSink->Initialize());
        return AllocateShared<MockSink>(decompressionSink, rawStream);
    };

    PullAndPushByVariousSize<1024>(makeSource, makeSink);
}


// Runs the buffer-size sweep of PullAndPushByVariousSize() on the AES-GCM streams.
TEST(AesGcmStream, PullAndPushByVariousSize)
{
    // Creates an AesGcmSource (using the test Encryptor) on top of a MemoryStream of
    // `size` bytes whose buffer is filled by FillBufferWith32BitCount().
    const auto makeSource = [](size_t size)
    {
        auto plainStream = AllocateShared<MemoryStream>(size);
        auto testEncryptor = AllocateShared<Encryptor>();
        auto encryptingSource = AllocateShared<AesGcmSource>(plainStream, std::move(testEncryptor));
        FillBufferWith32BitCount(plainStream->GetBuffer(), size, 0);

        return encryptingSource;
    };

    // Creates an AesGcmSink (using the test Decryptor) on top of a MemoryStream of `size`
    // bytes, wrapped in a MockSink that gives access to that stream's buffer.
    const auto makeSink = [](size_t size)
    {
        auto plainStream = AllocateShared<MemoryStream>(size);
        auto testDecryptor = AllocateShared<Decryptor>();
        auto decryptingSink = AllocateShared<AesGcmSink>(plainStream, size, std::move(testDecryptor));

        return AllocateShared<MockSink>(decryptingSink, plainStream);
    };

    PullAndPushByVariousSize<1024>(makeSource, makeSink);
}




// AES-GCM must detect tampering.
//
// Encrypts a 32-byte buffer, then checks that AesGcmSink::Finalize() fails with
// ResultSaveDataTransferImportMacVerificationFailed when the last byte of the stream or
// the byte at offset 32 has been altered, and that the untouched stream verifies and
// decrypts back to the original data.
TEST(AesGcmStream, MacVerification)
{
    const size_t BufferSize = 32;

    // src source
    auto memoryStreamSrc = AllocateShared<MemoryStream>(BufferSize);
    auto encryptor = AllocateShared<Encryptor>();
    AesGcmSource aesGcmSource(memoryStreamSrc, std::move(encryptor));
    FillBufferWith32BitCount(memoryStreamSrc->GetBuffer(), BufferSize, 0);

    // Pull the whole encrypted stream with a single call.
    const size_t PullBufferSize = 128;
    auto pulledBuffer = AllocateBuffer(PullBufferSize);
    ASSERT_NE(nullptr, pulledBuffer.get());
    // Zero-initialized because the result check below is non-fatal.
    size_t pulledSize = 0;
    NNT_EXPECT_RESULT_SUCCESS(aesGcmSource.Pull(&pulledSize, pulledBuffer.get(), PullBufferSize));
    // Fatal on mismatch: the offsets tampered with below (pulledSize - 1 and 32) are only
    // valid indexes when the stream has exactly this size.
    ASSERT_EQ(static_cast<size_t>(32 + 32 + 16), pulledSize);
    EXPECT_TRUE(aesGcmSource.IsEnd());

    // Tampering detection
    {
        auto memoryStreamDst = AllocateShared<MemoryStream>(BufferSize);
        auto decryptor = AllocateShared<Decryptor>();
        AesGcmSink aesGcmSink(memoryStreamDst, BufferSize, std::move(decryptor));

        pulledBuffer.get()[pulledSize - 1] ^= 0xFF; // tamper with the MAC (last byte of the stream)

        NNT_EXPECT_RESULT_SUCCESS(aesGcmSink.Push(pulledBuffer.get(), pulledSize));
        NNT_EXPECT_RESULT_FAILURE(ResultSaveDataTransferImportMacVerificationFailed, aesGcmSink.Finalize());

        // Flip the bits back so the following cases start from the genuine stream.
        pulledBuffer.get()[pulledSize - 1] ^= 0xFF;
    }
    {
        auto memoryStreamDst = AllocateShared<MemoryStream>(BufferSize);
        auto decryptor = AllocateShared<Decryptor>();
        AesGcmSink aesGcmSink(memoryStreamDst, BufferSize, std::move(decryptor));

        // Tamper with the data. NOTE(review): offset 32 is presumably the first payload
        // byte right after the stream header - confirm against sizeof(AesGcmStreamHeader).
        pulledBuffer.get()[32] ^= 0xFF;

        NNT_EXPECT_RESULT_SUCCESS(aesGcmSink.Push(pulledBuffer.get(), pulledSize));
        NNT_EXPECT_RESULT_FAILURE(ResultSaveDataTransferImportMacVerificationFailed, aesGcmSink.Finalize());

        // Flip the bits back so the final case sees the genuine stream.
        pulledBuffer.get()[32] ^= 0xFF;
    }

    // dst sink
    {
        auto memoryStreamDst = AllocateShared<MemoryStream>(BufferSize);
        auto decryptor = AllocateShared<Decryptor>();
        AesGcmSink aesGcmSink(memoryStreamDst, BufferSize, std::move(decryptor));

        NNT_EXPECT_RESULT_SUCCESS(aesGcmSink.Push(pulledBuffer.get(), pulledSize));

        // MAC verification
        NNT_EXPECT_RESULT_SUCCESS(aesGcmSink.Finalize());

        // The original data must be recovered.
        EXPECT_TRUE(IsFilledWith32BitCount(memoryStreamDst->GetBuffer(), BufferSize, 0));
    }
}


// Compression must shrink the data, and decompression must restore it correctly.
TEST(CompressionStream, Compression)
{
    const size_t BufferSize = 64 * 1024;

    // src source
    auto memoryStreamSrc = AllocateShared<MemoryStream>(BufferSize);
    CompressionSource source(memoryStreamSrc);
    NNT_EXPECT_RESULT_SUCCESS(source.Initialize());
    FillBufferWith32BitCount(memoryStreamSrc->GetBuffer(), BufferSize, 0);

    // Pull the whole compressed stream with a single call.
    const size_t PullBufferSize = BufferSize;
    auto pulledBuffer = AllocateBuffer(PullBufferSize);
    ASSERT_NE(nullptr, pulledBuffer.get());
    // Zero-initialized because the result check below is non-fatal: the value is still
    // read afterwards even when Pull() failed.
    size_t pulledSize = 0;
    NNT_EXPECT_RESULT_SUCCESS(source.Pull(&pulledSize, pulledBuffer.get(), PullBufferSize));
    EXPECT_LT(pulledSize, static_cast<size_t>(32 * 1024)); // must have shrunk (64 KB -> below 32 KB)
    EXPECT_TRUE(source.IsEnd());

    // dst sink
    {
        auto memoryStreamDst = AllocateShared<MemoryStream>(BufferSize);
        DecompressionSink sink(memoryStreamDst);
        NNT_EXPECT_RESULT_SUCCESS(sink.Initialize());

        NNT_EXPECT_RESULT_SUCCESS(sink.Push(pulledBuffer.get(), pulledSize));

        // The original data must be recovered.
        EXPECT_TRUE(IsFilledWith32BitCount(memoryStreamDst->GetBuffer(), BufferSize, 0));
    }
}


// Data that compresses to an extremely small size must be handled without problems.
TEST(CompressionStream, ExtremelyCompressibleData)
{
    const size_t BufferSize = 64 * 1024 * 1024;
    const unsigned char FillValue = 0xAB;

    // src source
    auto memoryStreamSrc = AllocateShared<MemoryStream>(BufferSize);
    ASSERT_NE(nullptr, memoryStreamSrc);
    // The buffers in this test are 64 MiB each; stop here instead of writing through a
    // null pointer if one of them could not be allocated.
    ASSERT_NE(nullptr, memoryStreamSrc->GetBuffer());
    CompressionSource source(memoryStreamSrc);
    NNT_EXPECT_RESULT_SUCCESS(source.Initialize());
    memset(memoryStreamSrc->GetBuffer(), FillValue, BufferSize);

    // Pull the whole compressed stream with a single call.
    const size_t PullBufferSize = BufferSize;
    auto pulledBuffer = AllocateBuffer(PullBufferSize);
    ASSERT_NE(nullptr, pulledBuffer.get());
    // Zero-initialized because the result check below is non-fatal: the value is still
    // read afterwards even when Pull() failed.
    size_t pulledSize = 0;
    NNT_EXPECT_RESULT_SUCCESS(source.Pull(&pulledSize, pulledBuffer.get(), PullBufferSize));
    EXPECT_LT(pulledSize, static_cast<size_t>(128 * 1024)); // must have shrunk (64 MB -> below 128 KB)
    EXPECT_TRUE(source.IsEnd());

    // dst sink
    {
        auto memoryStreamDst = AllocateShared<MemoryStream>(BufferSize);
        ASSERT_NE(nullptr, memoryStreamDst);
        ASSERT_NE(nullptr, memoryStreamDst->GetBuffer());
        DecompressionSink sink(memoryStreamDst);
        NNT_EXPECT_RESULT_SUCCESS(sink.Initialize());

        NNT_EXPECT_RESULT_SUCCESS(sink.Push(pulledBuffer.get(), pulledSize));

        // The original data must be recovered.
        EXPECT_TRUE(IsFilledWithValue(memoryStreamDst->GetBuffer(), BufferSize, FillValue));
    }
}


// Data that does not compress and grows instead must be handled without problems.
TEST(CompressionStream, ExpansiveData)
{
    const size_t BufferSize = 32;

    // src source
    auto memoryStreamSrc = AllocateShared<MemoryStream>(BufferSize);
    ASSERT_NE(nullptr, memoryStreamSrc);
    CompressionSource source(memoryStreamSrc);
    NNT_EXPECT_RESULT_SUCCESS(source.Initialize());
    FillBufferWith8BitCount(memoryStreamSrc->GetBuffer(), BufferSize, 0);

    // Pull in two steps of at most BufferSize bytes each into a buffer twice that size.
    const size_t PullBufferSize = BufferSize * 2;
    auto pulledBuffer = AllocateBuffer(PullBufferSize);
    ASSERT_NE(nullptr, pulledBuffer.get());
    size_t totalPulledSize = 0;
    {
        // Zero-initialized because the result check is non-fatal and the value is then
        // used as the write offset of the second pull.
        size_t pulledSize = 0;
        NNT_EXPECT_RESULT_SUCCESS(source.Pull(&pulledSize, pulledBuffer.get(), BufferSize));
        EXPECT_EQ(BufferSize, pulledSize); // the first pull fills the whole request
        EXPECT_FALSE(source.IsEnd()); // not finished yet
        totalPulledSize += pulledSize;
    }
    {
        size_t pulledSize = 0;
        NNT_EXPECT_RESULT_SUCCESS(source.Pull(&pulledSize, pulledBuffer.get() + totalPulledSize, BufferSize));
        EXPECT_GT(pulledSize, static_cast<size_t>(0));
        EXPECT_TRUE(source.IsEnd());
        totalPulledSize += pulledSize;
    }

    EXPECT_GT(totalPulledSize, BufferSize); // the stream must be larger than the input

    // dst sink
    {
        auto memoryStreamDst = AllocateShared<MemoryStream>(BufferSize);
        DecompressionSink sink(memoryStreamDst);
        NNT_EXPECT_RESULT_SUCCESS(sink.Initialize());

        NNT_EXPECT_RESULT_SUCCESS(sink.Push(pulledBuffer.get(), totalPulledSize));

        nnt::fs::util::DumpBufferDiff(memoryStreamDst->GetBuffer(), memoryStreamSrc->GetBuffer(), BufferSize);

        // The original data must be recovered.
        EXPECT_TRUE(IsFilledWith8BitCount(memoryStreamDst->GetBuffer(), BufferSize, 0));
    }
}


}

// Program entry point: sets up googletest and the fs / fssystem allocators, runs every
// test and exits with 1 when CheckMemoryLeak() reports a leak, otherwise with the value
// returned by RUN_ALL_TESTS().
extern "C" void nnMain()
{
    // Hand the host command line to googletest.
    int hostArgc = nnt::GetHostArgc();
    char** hostArgv = nnt::GetHostArgv();
    ::testing::InitGoogleTest(&hostArgc, hostArgv);

    // Route fssystem and fs allocations through the test allocator.
    nn::fssystem::InitializeAllocator(nnt::fs::util::Allocate, nnt::fs::util::Deallocate);
    nn::fs::SetAllocator(nnt::fs::util::Allocate, nnt::fs::util::Deallocate);

    // Static 8 MiB, 4096-byte aligned area handed to the fssystem buffer pool.
    static const size_t BufferPoolSize = 8 * 1024 * 1024;
    static NN_ALIGNAS(4096) char s_BufferPool[BufferPoolSize];
    nn::fssystem::InitializeBufferPool(s_BufferPool, BufferPoolSize);

    // Reset the allocation counter right before the tests run; CheckMemoryLeak() is
    // consulted once they have finished.
    nnt::fs::util::ResetAllocateCount();

    const int testResult = RUN_ALL_TESTS();

    const bool hasLeak = nnt::fs::util::CheckMemoryLeak();
    if (hasLeak)
    {
        nnt::Exit(1);
    }

    nnt::Exit(testResult);
}

