﻿/*--------------------------------------------------------------------------------*
  Copyright (C)Nintendo All rights reserved.

  These coded instructions, statements, and computer programs contain proprietary
  information of Nintendo and/or its licensed developers and are protected by
  national and international copyright laws. They may not be disclosed to third
  parties or copied or duplicated in any form, in whole or in part, without the
  prior written consent of Nintendo.

  The content herein is highly confidential and should be handled accordingly.
 *--------------------------------------------------------------------------------*/

#include <algorithm>
#include <cstring>
#include <functional>
#include <new>
#include <nn/fssystem/fs_BucketTreeUtility.h>
#include <nn/fssystem/utilTool/fs_StorageChecker.h>

namespace nn { namespace fssystem { namespace utilTool {

namespace {

// Size in bytes of each scratch buffer the AesCtrCounterExtendedStorageChecker
// allocates, and the largest chunk it reads when comparing storage contents.
const int BufferSize = 4 * 1024;

// Node view used to walk the 64-bit offsets held by the L1 and L2 nodes.
using Node = detail::BucketTreeNode<const int64_t*>;

}

// Validator for the contents of a BucketTree.
class BucketTreeChecker
{
public:
    // Checks the consistency of the tree behind pBucketTree and hands a pointer to
    // every entry to func so the caller can apply its own per-entry validation.
    // Returns the first failing Result, or success when no problem is found.
    static Result Verify(BucketTree* pBucketTree, std::function<Result(const void*)> func) NN_NOEXCEPT;
};

// Consistency check for a BucketTree. It mainly verifies that the offsets stored
// in the L1 node, in the L2 nodes and in the entry sets are strictly ascending
// and stay below the end offset of the tree.
//
// pBucketTree  Tree to inspect; must be non-null and initialized.
// func         Invoked with a pointer to each entry. A failure Result returned by
//              it stops the check and is returned unchanged.
//
// Returns fs::ResultNullptrArgument / fs::ResultInvalidArgument for bad arguments,
// fs::ResultInvalidBucketTreeNodeOffset / fs::ResultInvalidBucketTreeEntryOffset
// when an ordering rule is broken, fs::ResultAllocationMemoryFailedInBucketTreeCheckerA
// when the work buffer cannot be allocated, and otherwise forwards any failure
// coming from storage reads, node header verification or func.
Result BucketTreeChecker::Verify(
                              BucketTree* pBucketTree,
                              std::function<Result(const void*)> func
                          ) NN_NOEXCEPT
{
    NN_RESULT_THROW_UNLESS(pBucketTree != nullptr, fs::ResultNullptrArgument());
    NN_RESULT_THROW_UNLESS(pBucketTree->IsInitialized(), fs::ResultInvalidArgument());

    // There is nothing to inspect unless the entry size is positive.
    const bool isEntrySizePositive = 0 < pBucketTree->m_EntrySize;
    if( !isEntrySizePositive )
    {
        NN_RESULT_SUCCESS;
    }

    const auto& rootNode = *pBucketTree->m_NodeL1.Get<Node>();
    const auto treeEndOffset = pBucketTree->m_EndOffset;

    // Inspect the L1 node.
    {
        int64_t lastOffset = -1;

        // L2 offsets kept on the L1 node sit between GetEnd() and
        // GetBegin() + m_OffsetCount; they are checked before the L1 offsets.
        if( pBucketTree->IsExistOffsetL2OnL1() )
        {
            const auto l2PartEnd = rootNode.GetBegin() + pBucketTree->m_OffsetCount;
            for( auto pOffset = rootNode.GetEnd(); pOffset < l2PartEnd; ++pOffset )
            {
                NN_RESULT_THROW_UNLESS(lastOffset < *pOffset, fs::ResultInvalidBucketTreeNodeOffset());
                lastOffset = *pOffset;
            }
        }

        // The L1 offsets themselves continue the same ascending sequence.
        const auto l1End = rootNode.GetEnd();
        for( auto pOffset = rootNode.GetBegin(); pOffset < l1End; ++pOffset )
        {
            NN_RESULT_THROW_UNLESS(lastOffset < *pOffset, fs::ResultInvalidBucketTreeNodeOffset());
            lastOffset = *pOffset;
        }

        NN_RESULT_THROW_UNLESS(
            lastOffset < treeEndOffset, fs::ResultInvalidBucketTreeNodeOffset());
    }

    // One node-sized work buffer is shared by the L2 pass and the entry pass.
    const auto nodeSize = pBucketTree->m_NodeSize;
    const auto pAllocator = pBucketTree->m_NodeL1.GetAllocator();
    const auto pWork = pAllocator->allocate(nodeSize);
    NN_RESULT_THROW_UNLESS(
        pWork != nullptr, fs::ResultAllocationMemoryFailedInBucketTreeCheckerA());
    NN_UTIL_SCOPE_EXIT
    {
        pAllocator->deallocate(pWork, nodeSize);
    };

    auto& nodeStorage = pBucketTree->m_NodeStorage;

    // Inspect the L2 nodes.
    if( pBucketTree->IsExistL2() )
    {
        int64_t lastOffset = -1;

        for( int nodeIndex = 0; nodeIndex < rootNode.GetCount(); ++nodeIndex )
        {
            // Skip one node for the L1 node that precedes the L2 nodes.
            const auto readOffset = (nodeIndex + 1) * static_cast<int64_t>(nodeSize);
            NN_RESULT_DO(nodeStorage.Read(readOffset, pWork, nodeSize));

            const auto* pHeader = reinterpret_cast<const BucketTree::NodeHeader*>(pWork);
            NN_RESULT_DO(pHeader->Verify(nodeIndex, nodeSize, sizeof(int64_t)));

            const auto* pNode = reinterpret_cast<const Node*>(pWork);

            // The ascending order must hold across node boundaries as well,
            // so lastOffset is carried over from one L2 node to the next.
            const auto nodeEnd = pNode->GetEnd();
            for( auto pOffset = pNode->GetBegin(); pOffset < nodeEnd; ++pOffset )
            {
                NN_RESULT_THROW_UNLESS(lastOffset < *pOffset, fs::ResultInvalidBucketTreeNodeOffset());
                lastOffset = *pOffset;
            }
        }

        NN_RESULT_THROW_UNLESS(
            lastOffset < treeEndOffset, fs::ResultInvalidBucketTreeNodeOffset());
    }

    // Inspect the entries.
    const auto entrySize = pBucketTree->m_EntrySize;
    const auto entrySetCount = pBucketTree->m_EntrySetCount;
    auto& entryStorage = pBucketTree->m_EntryStorage;
    int64_t lastEntryOffset = -1;

    for( int setIndex = 0; setIndex < entrySetCount; ++setIndex )
    {
        const auto readOffset = setIndex * static_cast<int64_t>(nodeSize);
        NN_RESULT_DO(entryStorage.Read(readOffset, pWork, nodeSize));

        const auto* pHeader = reinterpret_cast<const BucketTree::NodeHeader*>(pWork);
        NN_RESULT_DO(pHeader->Verify(setIndex, nodeSize, entrySize));

        // Entries start right after the node header.
        util::BytePtr pEntry(pWork);
        pEntry.Advance(sizeof(BucketTree::NodeHeader));

        const auto entryCount = pHeader->count;
        for( int entryIndex = 0; entryIndex < entryCount; ++entryIndex )
        {
            // The first 8 bytes of an entry are read as its offset.
            const auto entryOffset = *pEntry.Get<int64_t>();
            NN_RESULT_THROW_UNLESS(lastEntryOffset < entryOffset, fs::ResultInvalidBucketTreeEntryOffset());
            lastEntryOffset = entryOffset;

            // Caller-supplied validation of the entry contents.
            NN_RESULT_DO(func(pEntry.Get()));

            pEntry.Advance(entrySize);
        }

        NN_RESULT_THROW_UNLESS(
            lastEntryOffset < treeEndOffset, fs::ResultInvalidBucketTreeEntryOffset());
    }

    NN_RESULT_SUCCESS;
}

// Validates the table of an IndirectStorage: runs the BucketTree consistency check
// and, for every entry, confirms that the virtual offset lies inside the table,
// the physical offset is not negative and the storage index is within
// [0, IndirectStorage::StorageCount).
//
// pStorage  Storage to inspect; must be non-null and initialized.
//
// Returns fs::ResultNullptrArgument / fs::ResultInvalidArgument for bad arguments,
// the matching fs::ResultInvalidIndirect* result for a bad entry, or whatever
// failure BucketTreeChecker::Verify reports.
Result IndirectStorageChecker::Verify(IndirectStorage* pStorage) NN_NOEXCEPT
{
    NN_RESULT_THROW_UNLESS(pStorage != nullptr, fs::ResultNullptrArgument());
    NN_RESULT_THROW_UNLESS(pStorage->IsInitialized(), fs::ResultInvalidArgument());

    // Per-entry validation, called once for each entry in the table.
    const auto verifyEntry = [pStorage](const void* pRawEntry) NN_NOEXCEPT -> Result
    {
        const auto* pIndirectEntry = static_cast<const IndirectStorage::Entry*>(pRawEntry);

        NN_RESULT_THROW_UNLESS(
            pStorage->m_Table.IsInclude(pIndirectEntry->GetVirtualOffset()),
            fs::ResultInvalidIndirectVirtualOffset()
        );

        NN_RESULT_THROW_UNLESS(
            0 <= pIndirectEntry->GetPhysicalOffset(),
            fs::ResultInvalidIndirectPhysicalOffset()
        );

        NN_RESULT_THROW_UNLESS(
            0 <= pIndirectEntry->storageIndex
                && pIndirectEntry->storageIndex < IndirectStorage::StorageCount,
            fs::ResultInvalidIndirectStorageIndex()
        );

        NN_RESULT_SUCCESS;
    };

    NN_RESULT_DO(BucketTreeChecker::Verify(&pStorage->m_Table, verifyEntry));

    NN_RESULT_SUCCESS;
}

// Validates the table of an AesCtrCounterExtendedStorage: runs the BucketTree
// consistency check and, for every entry, confirms that its offset lies inside
// the table, its reserved field is zero and its generation is within
// [0, generation].
//
// pStorage    Storage to inspect; must be non-null and initialized.
// generation  Highest generation number an entry is allowed to carry.
//
// Returns fs::ResultNullptrArgument / fs::ResultInvalidArgument for bad arguments,
// fs::ResultInvalidAesCtrCounterExtendedOffset,
// fs::ResultAesCtrCounterExtendedStorageCorrupted or
// fs::ResultInvalidAesCtrCounterExtendedGeneration for a bad entry, or whatever
// failure BucketTreeChecker::Verify reports.
Result AesCtrCounterExtendedStorageChecker::Verify(
                                                AesCtrCounterExtendedStorage* pStorage,
                                                uint32_t generation
                                            ) NN_NOEXCEPT
{
    NN_RESULT_THROW_UNLESS(pStorage != nullptr, fs::ResultNullptrArgument());
    NN_RESULT_THROW_UNLESS(pStorage->IsInitialized(), fs::ResultInvalidArgument());

    // Per-entry validation, called once for each entry in the table.
    const auto verifyEntry = [pStorage, generation](const void* pRawEntry) NN_NOEXCEPT -> Result
    {
        const auto* pCounterEntry =
            static_cast<const AesCtrCounterExtendedStorageChecker::Entry*>(pRawEntry);

        NN_RESULT_THROW_UNLESS(
            pStorage->m_Table.IsInclude(pCounterEntry->GetOffset()),
            fs::ResultInvalidAesCtrCounterExtendedOffset()
        );

        NN_RESULT_THROW_UNLESS(
            pCounterEntry->reserved == 0,
            fs::ResultAesCtrCounterExtendedStorageCorrupted()
        );

        NN_RESULT_THROW_UNLESS(
            0 <= pCounterEntry->generation
                && static_cast<uint32_t>(pCounterEntry->generation) <= generation,
            fs::ResultInvalidAesCtrCounterExtendedGeneration()
        );

        NN_RESULT_SUCCESS;
    };

    NN_RESULT_DO(BucketTreeChecker::Verify(&pStorage->m_Table, verifyEntry));

    NN_RESULT_SUCCESS;
}

// Allocates the two BufferSize-byte scratch buffers that CheckRange uses to
// compare storage contents, and clears the reported mismatch range.
//
// The buffers are obtained with the non-throwing form of new: on allocation
// failure they are left null instead of an exception being raised out of this
// NN_NOEXCEPT constructor. Verify(previous, current) tests both buffers for null
// and reports fs::ResultAllocationMemoryFailedInAesCtrCounterExtendedStorageCheckerA,
// so the failure surfaces as a Result there.
AesCtrCounterExtendedStorageChecker::AesCtrCounterExtendedStorageChecker() NN_NOEXCEPT
    : m_Buffer1(new (std::nothrow) char[BufferSize])
    , m_Buffer2(new (std::nothrow) char[BufferSize])
{
    m_ResultRange.Reset();
}

// Cross-checks the generation numbers of the current patch against the previous one.
//
// previous  Data storage, check storage (table), data offset and generation of
//           the previous patch.
// current   The same information for the current patch.
//
// Every current entry that overlaps the previous table is handed to CheckRange;
// every current entry visited after that must carry exactly current.generation.
// m_ResultRange is reset on entry; CheckRange stores the offending range in it
// when it detects a mismatch.
//
// Returns fs::ResultInvalidOffset, fs::ResultNullptrArgument or
// fs::ResultInvalidArgument for bad arguments,
// fs::ResultAllocationMemoryFailedInAesCtrCounterExtendedStorageCheckerA when the
// scratch buffers are null, fs::ResultInvalidAesCtrCounterExtendedGeneration when
// a generation rule is violated, or any failure forwarded from the table visitor
// or CheckRange.
Result AesCtrCounterExtendedStorageChecker::Verify(
                                                const StorageInfo& previous,
                                                const StorageInfo& current
                                            ) NN_NOEXCEPT
{
    NN_RESULT_THROW_UNLESS(0 <= previous.dataOffset, fs::ResultInvalidOffset());
    NN_RESULT_THROW_UNLESS(previous.pDataStorage != nullptr, fs::ResultNullptrArgument());
    NN_RESULT_THROW_UNLESS(previous.pCheckStorage != nullptr, fs::ResultNullptrArgument());
    NN_RESULT_THROW_UNLESS(previous.pCheckStorage->IsInitialized(), fs::ResultInvalidArgument());
    NN_RESULT_THROW_UNLESS(0 <= current.dataOffset, fs::ResultInvalidOffset());
    NN_RESULT_THROW_UNLESS(current.pDataStorage != nullptr, fs::ResultNullptrArgument());
    NN_RESULT_THROW_UNLESS(current.pCheckStorage != nullptr, fs::ResultNullptrArgument());
    NN_RESULT_THROW_UNLESS(current.pCheckStorage->IsInitialized(), fs::ResultInvalidArgument());
    // The scratch buffers are allocated by the constructor; refuse to run without them.
    NN_RESULT_THROW_UNLESS(
        m_Buffer1 != nullptr && m_Buffer2 != nullptr,
        fs::ResultAllocationMemoryFailedInAesCtrCounterExtendedStorageCheckerA()
    );

    m_ResultRange.Reset();

    // The generation number must strictly increase from the previous patch to the current one.
    NN_RESULT_THROW_UNLESS(
        previous.generation < current.generation,
        fs::ResultInvalidAesCtrCounterExtendedGeneration()
    );

    const auto& previousTable = previous.pCheckStorage->m_Table;
    const auto& currentTable = current.pCheckStorage->m_Table;

    // Start walking the current table from offset 0.
    BucketTree::Visitor currentVisitor;
    NN_RESULT_DO(currentTable.Find(&currentVisitor, 0));

    // endOffset is the smaller of the two tables' ends (dataOffset + table size),
    // expressed relative to current.dataOffset. It is negative when the previous
    // table ends before the current data begins.
    int64_t nextOffset = 0;
    auto endOffset = std::min(
        previous.dataOffset + previousTable.GetSize(),
        current.dataOffset + currentTable.GetSize()) - current.dataOffset;

    if( 0 <= endOffset )
    {
        auto entry = *currentVisitor.Get<Entry>();

        // Check the region where previous and current overlap.
        do
        {
            uint32_t generation = entry.generation;

            NN_RESULT_THROW_UNLESS(
                0 <= entry.generation && generation <= current.generation,
                fs::ResultInvalidAesCtrCounterExtendedGeneration()
            );

            const auto offset = entry.GetOffset();

            // The entry ends where the next one starts, or at the end of the table
            // for the last entry. Note that the visitor is advanced here, before
            // the range of the entry just read is checked.
            if( currentVisitor.CanMoveNext() )
            {
                NN_RESULT_DO(currentVisitor.MoveNext());

                const auto entry2 = *currentVisitor.Get<Entry>();
                nextOffset = entry2.GetOffset();

                entry = entry2;
            }
            else
            {
                nextOffset = currentTable.GetSize();
            }

            // Clip the entry's range to the overlap before checking it.
            const Range range =
            {
                offset,
                std::min(endOffset, nextOffset) - offset
            };
            NN_RESULT_DO(CheckRange(previous, current, range, generation));
        }
        while( nextOffset < endOffset );
    }

    // Check the part of current that extends past the overlap.
    // NOTE(review): when MoveNext succeeded in the loop above, the visitor already
    // points at the entry that starts at nextOffset, so the entry tested here is
    // not necessarily the one that straddles endOffset. Also, when nextOffset ==
    // endOffset after a successful move, that entry is tested neither here nor by
    // the loop below (which moves before testing). Confirm this is intended.
    if( endOffset < nextOffset )
    {
        const auto& entry = *currentVisitor.Get<Entry>();

        NN_RESULT_THROW_UNLESS(
            static_cast<uint32_t>(entry.generation) == current.generation,
            fs::ResultInvalidAesCtrCounterExtendedGeneration()
        );
    }

    // Check the remaining entries of current: each must carry current.generation.
    while( currentVisitor.CanMoveNext() )
    {
        NN_RESULT_DO(currentVisitor.MoveNext());

        const auto& entry = *currentVisitor.Get<Entry>();

        NN_RESULT_THROW_UNLESS(
            static_cast<uint32_t>(entry.generation) == current.generation,
            fs::ResultInvalidAesCtrCounterExtendedGeneration()
        );
    }

    NN_RESULT_SUCCESS;
}

// Validates one range of the current patch against the previous patch.
//
// previous    Storage information of the previous patch.
// current     Storage information of the current patch.
// range       Offset and size of the range, relative to the current table.
// generation  Generation number of the current entry that covers the range.
//
// While the range maps onto the previous table and generation is older than
// current.generation, the bytes of both data storages are compared in chunks of
// at most BufferSize bytes. If the function runs past the previous table (or the
// range starts before it), the range is accepted only when generation equals
// current.generation.
//
// On a data mismatch or a generation violation m_ResultRange is set to range and
// fs::ResultInvalidAesCtrCounterExtendedGeneration is returned. Failures from the
// table visitor or the storage reads are forwarded unchanged.
Result AesCtrCounterExtendedStorageChecker::CheckRange(
                                                const StorageInfo& previous,
                                                const StorageInfo& current,
                                                Range range,
                                                uint32_t generation
                                            ) NN_NOEXCEPT
{
    const auto& prevTable = previous.pCheckStorage->m_Table;

    // An older generation means the data is expected to match the previous patch.
    // NOTE: with the newest generation the data may be equal or different.
    const bool isOlderGeneration = generation < current.generation;

    // Start of the range translated into offsets of the previous table.
    auto prevOffset = range.offset + (current.dataOffset - previous.dataOffset);
    auto rest = range;

    if( 0 <= prevOffset )
    {
        while( prevOffset < prevTable.GetSize() )
        {
            // Determine where the previous entry containing prevOffset ends:
            // at the next entry's offset, or at the end of the table.
            BucketTree::Visitor prevVisitor;
            NN_RESULT_DO(prevTable.Find(&prevVisitor, prevOffset));

            int64_t segmentEnd = 0;
            if( prevVisitor.CanMoveNext() )
            {
                NN_RESULT_DO(prevVisitor.MoveNext());

                segmentEnd = prevVisitor.Get<Entry>()->GetOffset();
            }
            else
            {
                segmentEnd = prevTable.GetSize();
            }

            const auto segmentSize = segmentEnd - prevOffset;

            if( isOlderGeneration )
            {
                auto& prevData = *previous.pDataStorage;
                auto& currData = *current.pDataStorage;
                auto prevReadOffset = prevOffset;
                auto currReadOffset = rest.offset;

                // Compare the two storages chunk by chunk through the scratch buffers.
                for( auto bytesLeft = std::min(rest.size, segmentSize); 0 < bytesLeft; )
                {
                    const auto chunkSize =
                        static_cast<size_t>(std::min<int64_t>(bytesLeft, BufferSize));

                    NN_RESULT_DO(prevData.Read(prevReadOffset, m_Buffer1.get(), chunkSize));
                    NN_RESULT_DO(currData.Read(currReadOffset, m_Buffer2.get(), chunkSize));

                    const bool isSame =
                        std::memcmp(m_Buffer1.get(), m_Buffer2.get(), chunkSize) == 0;
                    if( !isSame )
                    {
                        // Remember which range failed so the caller can report it.
                        m_ResultRange = range;

                        NN_RESULT_THROW(fs::ResultInvalidAesCtrCounterExtendedGeneration());
                    }

                    prevReadOffset += chunkSize;
                    currReadOffset += chunkSize;
                    bytesLeft -= chunkSize;
                }
            }

            // The whole range has been checked.
            if( rest.size <= segmentSize )
            {
                NN_RESULT_SUCCESS;
            }

            // Move on to the part of the range that is still unchecked.
            rest.offset += segmentSize;
            rest.size -= segmentSize;

            prevOffset = segmentEnd;
        }
    }

    // Data that exists only in current must carry the current generation.
    if( generation != current.generation )
    {
        m_ResultRange = range;

        NN_RESULT_THROW(fs::ResultInvalidAesCtrCounterExtendedGeneration());
    }

    NN_RESULT_SUCCESS;
}

}}}
