﻿/*--------------------------------------------------------------------------------*
  Copyright (C)Nintendo All rights reserved.

  These coded instructions, statements, and computer programs contain proprietary
  information of Nintendo and/or its licensed developers and are protected by
  national and international copyright laws. They may not be disclosed to third
  parties or copied or duplicated in any form, in whole or in part, without the
  prior written consent of Nintendo.

  The content herein is highly confidential and should be handled accordingly.
 *--------------------------------------------------------------------------------*/

#pragma once

#include <nn/nn_Common.h>
#include <nn/fs/fs_Result.h>
#include <nn/fs/detail/fs_Newable.h>
#include <nn/util/util_Compression.h>
#include <nn/util/util_StreamingCompression.h>
#include <nn/util/util_Decompression.h>
#include <nn/util/util_StreamingDecompression.h>
#include <nn/util/util_BinTypes.h>
#include <nn/crypto/crypto_Sha256Generator.h>
#include <nn/crypto/crypto_Aes128GcmEncryptor.h>
#include <nn/crypto/crypto_Aes128GcmDecryptor.h>
#include <nn/crypto/crypto_Csrng.h>
#include <nn/crypto/crypto_Compare.h>
#include <nn/fssystem/fs_AsynchronousAccess.h>
#include <nn/fs/fs_SaveDataTransferVersion2.h>

namespace nn { namespace fssrv { namespace detail {

/**
 * @brief Pull-style byte source interface.
 *
 * Implemented in this file by StorageStream, CompressionSource and AesGcmSource.
 */
class ISource
{
public:
    virtual ~ISource() NN_NOEXCEPT {}

    // Fills `buffer` with at most `size` bytes and stores the number of bytes
    // actually produced in `*outValue` (this is how every implementation in
    // this file behaves).
    virtual Result Pull(size_t* outValue, char* buffer, size_t size) NN_NOEXCEPT = 0;

    // Stores the remaining raw data size in `*outValue`. StorageStream reports
    // the unread byte count of its storage; the wrapping sources in this file
    // forward the call to the source they wrap.
    virtual Result GetRestRawDataSize(int64_t* outValue) NN_NOEXCEPT = 0;

    // Returns true once the source has nothing more to deliver.
    virtual bool IsEnd() NN_NOEXCEPT = 0;
};

/**
 * @brief Push-style byte sink interface.
 *
 * Implemented in this file by StorageStream, DecompressionSink and AesGcmSink.
 */
class ISink
{
public:
    virtual ~ISink() NN_NOEXCEPT {}

    // Hands `size` bytes starting at `buffer` to the sink.
    virtual Result Push(const char* buffer, size_t size) NN_NOEXCEPT = 0;
    // TODO: SetSize()
};

/**
 * @brief Bidirectional stream: combines the ISource (Pull) and ISink (Push) interfaces.
 */
class IStream : public ISource, public ISink
{
public:
    virtual ~IStream() NN_NOEXCEPT {}
};



/**
 * @brief IStream backed by an fs::IStorage, accessed sequentially.
 *
 * Pull() and Push() share a single cursor (m_Offset) that advances by the
 * number of bytes transferred. The remaining-size counter is captured from
 * the storage at construction and is only decremented by Pull().
 */
class StorageStream : public IStream
{
public:
    /**
     * @brief Takes shared ownership of @p storage and records its current size.
     *
     * Aborts if the storage size cannot be obtained.
     */
    explicit StorageStream(std::shared_ptr<fs::IStorage> storage) NN_NOEXCEPT
        : m_Storage(std::move(storage))
    {
        NN_ABORT_UNLESS_RESULT_SUCCESS(m_Storage->GetSize(&m_RestStorageSize));
    }

public:

    /**
     * @brief Reads up to @p size bytes from the current offset.
     *
     * @param[out] outValue Number of bytes read (clamped to the remaining size).
     * @param[out] buffer   Destination buffer.
     * @param[in]  size     Capacity of @p buffer in bytes.
     * @return The storage's read result on failure; success otherwise.
     */
    virtual Result Pull(size_t* outValue, char* buffer, size_t size) NN_NOEXCEPT NN_OVERRIDE
    {
        // Never read past what is left in the storage.
        const int64_t requestedSize = static_cast<int64_t>(size);
        const int64_t clampedSize = (m_RestStorageSize < requestedSize) ? m_RestStorageSize : requestedSize;
        const size_t transferSize = static_cast<size_t>(clampedSize);

        NN_RESULT_DO(m_Storage->Read(m_Offset, buffer, transferSize));

        m_Offset += transferSize;
        m_RestStorageSize -= transferSize;

        *outValue = transferSize;
        NN_RESULT_SUCCESS;
    }

    /**
     * @brief Reports how many bytes of the storage have not yet been pulled.
     */
    virtual Result GetRestRawDataSize(int64_t* outValue) NN_NOEXCEPT NN_OVERRIDE
    {
        *outValue = m_RestStorageSize;
        NN_RESULT_SUCCESS;
    }

    /**
     * @brief True once every byte counted at construction has been pulled.
     */
    virtual bool IsEnd() NN_NOEXCEPT NN_OVERRIDE
    {
        return 0 == m_RestStorageSize;
    }

    /**
     * @brief Writes @p size bytes at the current offset and advances the cursor.
     *
     * @return The storage's write result on failure; success otherwise.
     */
    virtual Result Push(const char* buffer, size_t size) NN_NOEXCEPT NN_OVERRIDE
    {
        NN_RESULT_DO(m_Storage->Write(m_Offset, buffer, size));
        m_Offset += size;
        NN_RESULT_SUCCESS;
    }

private:

    std::shared_ptr<fs::IStorage> m_Storage;
    int64_t m_RestStorageSize;   // bytes not yet pulled
    int64_t m_Offset = 0;        // shared read/write cursor
};

/**
 * @brief ISource decorator that deflate-compresses the bytes pulled from another source.
 *
 * Raw data is staged in a pooled buffer, fed to the streaming deflate context,
 * and the compressed output is written straight into the caller's buffer.
 * Initialize() must succeed before Pull() is used.
 */
class CompressionSource : public ISource
{
private:
    static const size_t CompressionWorkBufferSize = 16 * 1024;

private:
    // The context work buffer doubles as the "initialized" flag.
    bool IsInitialized()
    {
        return nullptr != m_WorkBufferForContext.GetBuffer();
    }


public:
    /**
     * @brief Wraps @p source; no buffers are allocated until Initialize().
     */
    explicit CompressionSource(std::shared_ptr<ISource> source) NN_NOEXCEPT
        : m_Source(std::move(source))
    {
    }

    /**
     * @brief Allocates the work buffers and sets up the streaming deflate context.
     *
     * @return fs::ResultAllocationMemoryFailed if either pooled buffer came back
     *         smaller than requested; success otherwise.
     */
    Result Initialize()
    {
        const auto contextWorkSize = util::CompressDeflateWorkBufferSizeDefault;

        // Provisional: use the default work buffer size.
        m_WorkBufferForContext.Allocate(contextWorkSize, contextWorkSize);
        // Request only the minimum required size, and wait until it can be obtained.
        m_RawBuffer.Allocate(CompressionWorkBufferSize, CompressionWorkBufferSize);

        NN_RESULT_THROW_UNLESS(m_WorkBufferForContext.GetSize() >= contextWorkSize, fs::ResultAllocationMemoryFailed());
        NN_RESULT_THROW_UNLESS(m_RawBuffer.GetSize() >= CompressionWorkBufferSize, fs::ResultAllocationMemoryFailed());

        util::InitializeStreamingCompressDeflateContext(&m_CompressionContext, m_WorkBufferForContext.GetBuffer(), m_WorkBufferForContext.GetSize());

        NN_RESULT_SUCCESS;
    }

    /**
     * @brief Releases both pooled buffers.
     */
    void Finalize()
    {
        m_WorkBufferForContext.Deallocate();
        m_RawBuffer.Deallocate();
    }

public:
    /**
     * @brief Produces up to @p size bytes of compressed data.
     *
     * Loops until the destination is full or the compressor reports that it
     * neither produced nor consumed anything, which is treated as "compression
     * complete".
     *
     * @param[out] outValue Number of compressed bytes written to @p buffer.
     * @param[out] buffer   Destination for compressed data.
     * @param[in]  size     Capacity of @p buffer in bytes.
     * @return fs::ResultPreconditionViolation if Initialize() has not been
     *         called; the inner source's error if its Pull fails.
     */
    virtual Result Pull(size_t* outValue, char* buffer, size_t size) NN_NOEXCEPT NN_OVERRIDE
    {
        NN_RESULT_THROW_UNLESS(IsInitialized(), fs::ResultPreconditionViolation());

        size_t writtenSize = 0;
        size_t dstRemaining = size;

        while (dstRemaining > 0)
        {
            // Top up the free tail of the staging buffer with fresh raw data.
            size_t pulledSize;
            NN_RESULT_DO(m_Source->Pull(&pulledSize, m_RawBuffer.GetBuffer() + m_RawBufferValidDataSize, m_RawBuffer.GetSize() - m_RawBufferValidDataSize));
            m_RawBufferValidDataSize += pulledSize;

            // NOTE(review): the return value of StreamingCompressDeflate is not
            // checked here — confirm whether it can report failure.
            size_t producedSize;
            size_t consumedSize;
            util::StreamingCompressDeflate(&producedSize, &consumedSize, buffer + writtenSize, dstRemaining, m_RawBuffer.GetBuffer(), m_RawBufferValidDataSize, &m_CompressionContext);

            const bool madeNoProgress = (producedSize == 0) && (consumedSize == 0);
            if (madeNoProgress)
            {
                // Compression complete.
                m_IsCompressionComplete = true;
                break;
            }

            // Slide the unconsumed raw bytes to the front of the staging buffer.
            NN_SDK_ASSERT(consumedSize <= m_RawBufferValidDataSize);
            memmove(m_RawBuffer.GetBuffer(), m_RawBuffer.GetBuffer() + consumedSize, m_RawBufferValidDataSize - consumedSize);
            m_RawBufferValidDataSize -= consumedSize;

            writtenSize += producedSize;
            dstRemaining -= producedSize;
        }

        *outValue = writtenSize;
        NN_RESULT_SUCCESS;
    }

    /**
     * @brief Forwards to the wrapped source (size is in raw, uncompressed bytes of that source).
     */
    virtual Result GetRestRawDataSize(int64_t* outValue) NN_NOEXCEPT NN_OVERRIDE
    {
        return m_Source->GetRestRawDataSize(outValue);
    }

    /**
     * @brief True when compression has completed, the staging buffer is empty,
     *        and the wrapped source reports its end.
     */
    virtual bool IsEnd() NN_NOEXCEPT NN_OVERRIDE
    {
        if (!m_IsCompressionComplete)
        {
            return false;
        }
        if (m_RawBufferValidDataSize != 0)
        {
            return false;
        }
        return m_Source->IsEnd();
    }

private:
    std::shared_ptr<ISource> m_Source;
    util::StreamingCompressDeflateContext m_CompressionContext;

    fssystem::PooledBuffer m_WorkBufferForContext;  // backing memory for the deflate context
    fssystem::PooledBuffer m_RawBuffer;             // staging area for raw (uncompressed) input
    size_t m_RawBufferValidDataSize = 0;            // bytes of pending raw data in m_RawBuffer
    bool m_IsCompressionComplete = false;
};

/**
 * @brief ISink decorator that inflates deflate-compressed input and forwards the result.
 *
 * Compressed bytes given to Push() are decompressed into a pooled staging
 * buffer and pushed on to the wrapped sink chunk by chunk.
 * Initialize() must succeed before Push() is used.
 */
class DecompressionSink : public ISink
{
private:
    static const size_t DecompressionWorkBufferSize = 16 * 1024;

private:
    // The context work buffer doubles as the "initialized" flag.
    bool IsInitialized()
    {
        return nullptr != m_WorkBufferForContext.GetBuffer();
    }


public:
    /**
     * @brief Wraps @p sink; no buffers are allocated until Initialize().
     */
    explicit DecompressionSink(std::shared_ptr<ISink> sink) NN_NOEXCEPT
        : m_Sink(std::move(sink))
    {
    }

    /**
     * @brief Allocates the work buffers and resets the streaming inflate context.
     *
     * @return fs::ResultAllocationMemoryFailed if either pooled buffer came back
     *         smaller than requested; success otherwise.
     */
    Result Initialize()
    {
        const auto contextWorkSize = util::CompressDeflateWorkBufferSizeDefault;

        // Provisional: use the default work buffer size.
        m_WorkBufferForContext.Allocate(contextWorkSize, contextWorkSize);
        // Request only the minimum required size, and wait until it can be obtained.
        m_RawBuffer.Allocate(DecompressionWorkBufferSize, DecompressionWorkBufferSize);

        NN_RESULT_THROW_UNLESS(m_WorkBufferForContext.GetSize() >= contextWorkSize, fs::ResultAllocationMemoryFailed());
        NN_RESULT_THROW_UNLESS(m_RawBuffer.GetSize() >= DecompressionWorkBufferSize, fs::ResultAllocationMemoryFailed());

        // The context work buffer is not handed to the decompression context.
        util::InitializeStreamingDecompressDeflateContext(&m_DecompressionContext); // , m_WorkBufferForContext.GetBuffer(), m_WorkBufferForContext.GetSize());

        NN_RESULT_SUCCESS;
    }

    /**
     * @brief Releases both pooled buffers.
     */
    void Finalize()
    {
        m_WorkBufferForContext.Deallocate();
        m_RawBuffer.Deallocate();
    }

public:

    /**
     * @brief Decompresses @p size bytes from @p buffer and pushes the output downstream.
     *
     * @return fs::ResultPreconditionViolation if Initialize() has not been called;
     *         fs::ResultSaveDataTransferImportDataDecompressionFailed if the
     *         decompressor reports failure; the wrapped sink's error if its Push fails.
     */
    virtual Result Push(const char* buffer, size_t size) NN_NOEXCEPT NN_OVERRIDE
    {
        NN_RESULT_THROW_UNLESS(IsInitialized(), fs::ResultPreconditionViolation());

        const char* pSrc = buffer;
        size_t srcRemaining = size;

        // NOTE(review): the loop stops as soon as the input is consumed; if the
        // context can still hold output that did not fit in m_RawBuffer, that
        // output is only emitted by a later Push() — confirm against the
        // StreamingDecompressDeflate contract.
        while (srcRemaining > 0)
        {
            // Fill the staging buffer with decompressed data.
            size_t producedSize;
            size_t consumedSize;
            const bool succeeded = util::StreamingDecompressDeflate(&producedSize, &consumedSize, m_RawBuffer.GetBuffer(), m_RawBuffer.GetSize(), pSrc, srcRemaining, &m_DecompressionContext);

            if (!succeeded)
            {
                return fs::ResultSaveDataTransferImportDataDecompressionFailed();
            }

            const bool madeNoProgress = (producedSize == 0) && (consumedSize == 0);
            if (madeNoProgress)
            {
                // Decompression of this input is finished, so the push is done.
                break;
            }

            // Forward the decompressed chunk to the wrapped sink.
            NN_RESULT_DO(m_Sink->Push(m_RawBuffer.GetBuffer(), producedSize));

            pSrc += consumedSize;
            srcRemaining -= consumedSize;
        }

        NN_RESULT_SUCCESS;
    }


private:

    std::shared_ptr<ISink> m_Sink;
    util::StreamingDecompressDeflateContext m_DecompressionContext;

    fssystem::PooledBuffer m_WorkBufferForContext;  // allocated in Initialize(); used as the "initialized" flag
    fssystem::PooledBuffer m_RawBuffer;             // staging area for decompressed output
};

/**
 * @brief Fixed 32-byte header placed at the start of an AES-GCM stream.
 *
 * AesGcmSource emits these bytes before the encrypted body, and AesGcmSink
 * collects them before initializing its decryptor. The bytes are copied
 * to/from the struct verbatim, so the layout must not change.
 */
struct AesGcmStreamHeader
{
    // Four-character signature 'A','G','S','0'.
    static const uint32_t Signature = NN_UTIL_CREATE_SIGNATURE_4('A', 'G', 'S', '0');

public:
    uint32_t signature = Signature;
    int16_t  version = 0;
    int16_t  keyGeneration = 0;   // NOTE(review): presumably selects the key generation used by the encryptor/decryptor — confirm
    char     reserved[8] = { 0 };
    char     iv[16] = { 0 };      // NOTE(review): presumably the AES-GCM IV, filled in by IEncryptor::Initialize — confirm
};
// The header is serialized by raw memcpy, so its size is pinned.
NN_STATIC_ASSERT(sizeof(AesGcmStreamHeader) == 32);
// NOTE(review): presumably skipped on GCC because std::is_trivially_copyable is unavailable there — confirm.
#if !defined(NN_BUILD_CONFIG_COMPILER_GCC)
NN_STATIC_ASSERT(std::is_trivially_copyable<AesGcmStreamHeader>::value);
#endif

/**
 * @brief Fixed 16-byte trailer holding a MAC.
 *
 * NOTE(review): presumably describes the MAC that AesGcmSource appends after
 * the encrypted body; the struct itself is not referenced in the visible code.
 */
struct AesGcmStreamTail
{
public:
    char mac[16] = { 0 };
};
// Size is pinned because the trailer is a raw byte layout.
NN_STATIC_ASSERT(sizeof(AesGcmStreamTail) == 16);
// NOTE(review): presumably skipped on GCC because std::is_trivially_copyable is unavailable there — confirm.
#if !defined(NN_BUILD_CONFIG_COMPILER_GCC)
NN_STATIC_ASSERT(std::is_trivially_copyable<AesGcmStreamTail>::value);
#endif

/**
 * @brief ISource decorator that emits an AES-GCM protected stream.
 *
 * The produced byte sequence is: AesGcmStreamHeader, then the wrapped source's
 * data passed through the encryptor, then a MacSize-byte MAC. The body is
 * considered finished the first time the wrapped source returns fewer bytes
 * than were requested; at that point the MAC is fetched from the encryptor.
 */
class AesGcmSource : public ISource
{
public:
    static const size_t MacSize = crypto::Aes128GcmEncryptor::MacSize;
    static const size_t BlockSize = crypto::Aes128GcmEncryptor::BlockSize;

public:
    /**
     * @brief Encryption backend used by AesGcmSource.
     */
    class IEncryptor
    {
    public:
        virtual ~IEncryptor() NN_NOEXCEPT {}
        // Called once from the AesGcmSource constructor with the header to populate.
        virtual void   Initialize(AesGcmStreamHeader* pOutHeader) NN_NOEXCEPT = 0;
        // Encrypts srcSize bytes; AesGcmSource calls it in place (pDst == pSrc)
        // and expects the returned size to equal srcSize.
        virtual size_t Update(void* pDst, size_t dstSize, const void* pSrc, size_t srcSize) NN_NOEXCEPT = 0;
        // Writes the MAC for everything passed to Update() so far.
        virtual void   GetMac(void* pMac, size_t macSize) NN_NOEXCEPT = 0;
    };

public:
    /**
     * @brief Wraps @p source and lets @p encryptor fill in the stream header.
     */
    AesGcmSource(std::shared_ptr<ISource> source, std::shared_ptr<IEncryptor> encryptor) NN_NOEXCEPT
        : m_Source(std::move(source))
        , m_Encryptor(std::move(encryptor))
    {
        m_Encryptor->Initialize(&m_Header);
    }

    /**
     * @brief Copies the stream MAC into @p pOutMacBuffer.
     *
     * The MAC only becomes valid once Pull() has exhausted the wrapped source;
     * before that the buffer is filled with zeros.
     *
     * @return fs::ResultInvalidArgument unless @p outMacBufferSize equals MacSize.
     */
    Result GetMac(char* pOutMacBuffer, size_t outMacBufferSize) NN_NOEXCEPT
    {
        NN_RESULT_THROW_UNLESS(outMacBufferSize == MacSize, fs::ResultInvalidArgument());
        memcpy(pOutMacBuffer, m_Mac, outMacBufferSize);
        NN_RESULT_SUCCESS;
    }

public:

    /**
     * @brief Produces up to @p size bytes of the protected stream.
     *
     * @param[out] outValue Total number of bytes written (header + body + MAC parts).
     * @param[out] buffer   Destination buffer.
     * @param[in]  size     Capacity of @p buffer in bytes.
     * @return The wrapped source's error if its Pull fails; success otherwise.
     */
    virtual Result Pull(size_t* outValue, char* buffer, size_t size) NN_NOEXCEPT NN_OVERRIDE
    {
        size_t bufferOffset = 0;        // also the running total of bytes produced
        size_t restBufferSize = size;

        // header
        if (m_RestHeaderSize > 0)
        {
            size_t copySize = static_cast<size_t>(std::min<int64_t>(m_RestHeaderSize, restBufferSize));
            memcpy(buffer + bufferOffset, reinterpret_cast<char*>(&m_Header) + sizeof(m_Header) - m_RestHeaderSize, copySize);
            m_RestHeaderSize -= copySize;
            bufferOffset     += copySize;
            restBufferSize   -= copySize;
        }

        // body
        if (!m_IsPullComplete)
        {
            size_t readSize;
            // TODO: as a precaution, consider avoiding plaintext data being placed in the user buffer
            NN_RESULT_DO(m_Source->Pull(&readSize, buffer + bufferOffset, restBufferSize));

            // Encrypt in place.
            auto encryptedSize = m_Encryptor->Update(buffer + bufferOffset, readSize, buffer + bufferOffset, readSize);
            NN_SDK_ASSERT(encryptedSize == readSize);
            NN_UNUSED(encryptedSize);

            // A short read marks the end of the body: capture the MAC now.
            if (readSize < restBufferSize)
            {
                m_IsPullComplete = true;
                m_Encryptor->GetMac(m_Mac, sizeof(m_Mac));
                m_RestMacSize = MacSize;
            }

            bufferOffset += readSize;
            restBufferSize -= readSize;
        }

        // tail: mac
        if (m_IsPullComplete && m_RestMacSize > 0)
        {
            size_t copySize = static_cast<size_t>(std::min<int64_t>(m_RestMacSize, restBufferSize));
            memcpy(buffer + bufferOffset, m_Mac + MacSize - m_RestMacSize, copySize);
            m_RestMacSize -= copySize;
            bufferOffset += copySize;
            restBufferSize -= copySize;
        }

        *outValue = bufferOffset;
        NN_RESULT_SUCCESS;
    }

    /**
     * @brief Forwards to the wrapped source (header and MAC bytes are not counted).
     */
    virtual Result GetRestRawDataSize(int64_t* outValue) NN_NOEXCEPT NN_OVERRIDE
    {
        return m_Source->GetRestRawDataSize(outValue);
    }

    /**
     * @brief True once the body is finished and every MAC byte has been emitted.
     */
    virtual bool IsEnd() NN_NOEXCEPT NN_OVERRIDE
    {
        return m_IsPullComplete && m_RestMacSize == 0;
    }

private:
    int64_t m_RestHeaderSize = sizeof(AesGcmStreamHeader);  // header bytes still to emit
    AesGcmStreamHeader m_Header;
    std::shared_ptr<ISource> m_Source;
    std::shared_ptr<IEncryptor> m_Encryptor;
    bool m_IsPullComplete = false;                          // body fully pulled and MAC captured
    int64_t m_RestMacSize = 0;                              // MAC bytes still to emit
    char m_Mac[MacSize] = {};                               // zeroed so an early GetMac() never exposes indeterminate bytes
};

/**
 * @brief ISink decorator that consumes an AES-GCM protected stream.
 *
 * Expects the byte sequence produced by AesGcmSource: AesGcmStreamHeader,
 * `size` bytes of encrypted body, then a MacSize-byte MAC. The body is
 * decrypted through a pooled staging buffer and pushed to the wrapped sink;
 * the MAC is stored and checked in Finalize().
 */
class AesGcmSink : public ISink
{
public:
    static const size_t MacSize = crypto::Aes128GcmEncryptor::MacSize;
    static const size_t BlockSize = crypto::Aes128GcmEncryptor::BlockSize;

public:
    /**
     * @brief Decryption backend used by AesGcmSink.
     */
    class IDecryptor
    {
    public:
        virtual ~IDecryptor() NN_NOEXCEPT {}
        // Called once, as soon as the complete stream header has been received.
        virtual void   Initialize(const AesGcmStreamHeader& header) NN_NOEXCEPT = 0;
        // Decrypts srcSize bytes into pDst; AesGcmSink expects the returned
        // size to equal srcSize.
        virtual size_t Update(void* pDst, size_t dstSize, const void* pSrc, size_t srcSize) NN_NOEXCEPT = 0;
        // Writes the MAC computed over everything passed to Update() so far.
        virtual void   GetMac(void* pMac, size_t macSize) NN_NOEXCEPT = 0;
    };

public:
    /**
     * @brief Wraps @p sink.
     *
     * @param[in] sink      Destination for decrypted data.
     * @param[in] size      Size in bytes of the encrypted body (header and MAC excluded).
     * @param[in] decryptor Decryption backend.
     */
    AesGcmSink(std::shared_ptr<ISink> sink, int64_t size, std::shared_ptr<IDecryptor> decryptor) NN_NOEXCEPT
        : m_Sink(std::move(sink))
        , m_Decryptor(std::move(decryptor))
        , m_RestSize(size + MacSize)
    {
        // TODO: error handling (a failed allocation is currently detected in Push())
        m_RawBuffer.Allocate(16 * 1024, 16 * 1024);
    }

public:

    /**
     * @brief Consumes the next @p size bytes of the protected stream.
     *
     * Bytes are routed, in order, to the header, the decrypted body and the
     * stored MAC. Bytes beyond the expected end of the stream are ignored.
     *
     * @return fs::ResultAllocationMemoryFailed if the staging buffer is empty;
     *         the wrapped sink's error if its Push fails; success otherwise.
     */
    virtual Result Push(const char* buffer, size_t size) NN_NOEXCEPT NN_OVERRIDE
    {
        size_t restSrcSize = size;
        size_t bufferOffset = 0;

        // header
        if (m_RestHeaderSize > 0)
        {
            size_t copySize = static_cast<size_t>(std::min<int64_t>(m_RestHeaderSize, restSrcSize));
            memcpy(reinterpret_cast<char*>(&m_Header) + sizeof(m_Header) - m_RestHeaderSize, buffer + bufferOffset, copySize);
            restSrcSize      -= copySize;
            bufferOffset     += copySize;
            m_RestHeaderSize -= copySize;

            // The decryptor can only be set up once the whole header is known.
            if (m_RestHeaderSize == 0)
            {
                m_Decryptor->Initialize(m_Header);
            }
        }

        // body: everything except the trailing MacSize bytes
        while (restSrcSize > 0 && m_RestSize > static_cast<int64_t>(MacSize))
        {
            size_t decryptSize = static_cast<size_t>(std::min(static_cast<int64_t>(m_RestSize - MacSize), static_cast<int64_t>(std::min(restSrcSize, m_RawBuffer.GetSize()))));

            // Both the remaining input and the remaining body are non-zero here,
            // so a zero chunk means the staging buffer has no capacity (the
            // allocation in the constructor did not succeed). Fail instead of
            // looping forever.
            NN_RESULT_THROW_UNLESS(decryptSize > 0, fs::ResultAllocationMemoryFailed());

            auto decryptedSize = m_Decryptor->Update(m_RawBuffer.GetBuffer(), m_RawBuffer.GetSize(), buffer + bufferOffset, decryptSize);
            NN_SDK_ASSERT(decryptedSize == decryptSize);
            NN_UNUSED(decryptedSize);

            NN_RESULT_DO(m_Sink->Push(m_RawBuffer.GetBuffer(), decryptSize));

            restSrcSize -= decryptSize;
            bufferOffset += decryptSize;
            m_RestSize -= decryptSize;
        }

        // mac
        if (restSrcSize > 0 && m_RestSize > 0)
        {
            NN_SDK_ASSERT(static_cast<int64_t>(MacSize) >= m_RestSize);
            const size_t macOffset = MacSize - static_cast<size_t>(m_RestSize);
            size_t copySize = std::min(restSrcSize, static_cast<size_t>(m_RestSize));

            memcpy(&m_Mac[macOffset], &buffer[bufferOffset], copySize);
            restSrcSize -= copySize;
            bufferOffset += copySize;
            m_RestSize -= copySize;
        }

        NN_RESULT_SUCCESS;
    }

public:
    /**
     * @brief Verifies the MAC received at the end of the stream.
     *
     * @return fs::ResultSaveDataTransferImportMacVerificationFailed if the
     *         stream is incomplete (header, body or MAC not fully received) or
     *         the received MAC does not match the one computed by the
     *         decryptor; success otherwise.
     */
    Result Finalize() NN_NOEXCEPT
    {
        // A truncated stream cannot be authenticated: without the full header
        // the decryptor was never initialized, and without the full tail
        // m_Mac does not hold a complete MAC.
        NN_RESULT_THROW_UNLESS(m_RestHeaderSize == 0 && m_RestSize == 0, fs::ResultSaveDataTransferImportMacVerificationFailed());

        // TODO: Flush stream
        char actualMac[MacSize];
        m_Decryptor->GetMac(actualMac, sizeof(actualMac));
        NN_RESULT_THROW_UNLESS(crypto::IsSameBytes(actualMac, m_Mac, MacSize), fs::ResultSaveDataTransferImportMacVerificationFailed());

        // TODO: Finalize stream
        //NN_RESULT_DO(m_Stream->Finalize);

        NN_RESULT_SUCCESS;
    }

private:
    int64_t m_RestHeaderSize = sizeof(AesGcmStreamHeader);  // header bytes still expected
    AesGcmStreamHeader m_Header;

    std::shared_ptr<ISink> m_Sink;
    std::shared_ptr<IDecryptor> m_Decryptor;

    int64_t m_RestSize;           // body + MAC bytes still expected
    char m_Mac[MacSize] = {};     // MAC received from the stream; zeroed until filled

    fssystem::PooledBuffer m_RawBuffer;  // staging area for decrypted output
};

}}}

