﻿/*--------------------------------------------------------------------------------*
  Copyright (C)Nintendo All rights reserved.

  These coded instructions, statements, and computer programs contain proprietary
  information of Nintendo and/or its licensed developers and are protected by
  national and international copyright laws. They may not be disclosed to third
  parties or copied or duplicated in any form, in whole or in part, without the
  prior written consent of Nintendo.

  The content herein is highly confidential and should be handled accordingly.
 *--------------------------------------------------------------------------------*/

#include <nn/nn_Common.h>
#include <nn/nn_SdkAssert.h>
#include <nn/nn_StaticAssert.h>
#include <nn/os/os_SdkMutex.h>

#include <nnt/gtest/detail/gtest-heap.h>

namespace nnt { namespace testing { namespace detail {

namespace {

// Namespace-like holder for storage types with a fixed alignment.
//
// Only the alignments that have an explicit specialization below (4 and 8
// bytes) can be used; the primary template is declared but never defined, so
// any other AlignmentByteCount fails to compile at the point of use.
class Alignment final
{
public:
    // Raw byte buffer of StorageByteCount bytes aligned to AlignmentByteCount.
    template<size_t StorageByteCount, size_t AlignmentByteCount>
    struct AlignedStorage;

    // 4-byte aligned buffer.
    template<size_t StorageByteCount>
    struct AlignedStorage<StorageByteCount, 4>
    {
        NN_ALIGNAS(4)
        char data[StorageByteCount];
    };

    // 8-byte aligned buffer.
    template<size_t StorageByteCount>
    struct AlignedStorage<StorageByteCount, 8>
    {
        NN_ALIGNAS(8)
        char data[StorageByteCount];
    };

private:
    // Declared but not defined: this class is not meant to be instantiated.
    Alignment();

    NN_DISALLOW_COPY(Alignment);
    NN_DISALLOW_MOVE(Alignment);
};

template<
    size_t AtomicBlockByteCount, int AtomicBlockCount,
    size_t AlignmentByteCount>
class BuddyMemoryAllocator final
{
private:
    NN_STATIC_ASSERT(AtomicBlockByteCount > 0);
    NN_STATIC_ASSERT(AtomicBlockCount > 0);
    NN_STATIC_ASSERT(AlignmentByteCount >= sizeof(int));
    NN_STATIC_ASSERT(AtomicBlockByteCount % AlignmentByteCount == 0);

    static const int BlockListCountMax = sizeof(int) * 8 - 1;

    static int GetBlockCountOfSize(size_t size) NN_NOEXCEPT
    {
        return static_cast<int>(
            (size + AtomicBlockByteCount - 1) / AtomicBlockByteCount);
    }

    static int GetBlockCountOfBlockOrder(int order) NN_NOEXCEPT
    {
        NN_SDK_REQUIRES_RANGE(order, 0, BlockListCountMax);

        return 1 << order;
    }

    static int GetBlockOrder(int blockCount) NN_NOEXCEPT
    {
        NN_SDK_REQUIRES_RANGE(
            blockCount,
            1,
            GetBlockCountOfBlockOrder(BlockListCountMax - 1) + 1);

        for (int i = 0; i < BlockListCountMax; ++i)
        {
            if (blockCount <= GetBlockCountOfBlockOrder(i))
            {
                return i;
            }
        }

        return -1;
    }

    struct Block final
    {
        Block* pNext;
    };

    class BlockList final
    {
    private:
        Block* m_pHead;
        Block* m_pTail;

    public:
        void Initialize() NN_NOEXCEPT
        {
            this->m_pHead = nullptr;
            this->m_pTail = nullptr;
        }

        bool IsEmpty() const NN_NOEXCEPT
        {
            return (this->m_pHead == nullptr);
        }

        void Enqueue(Block* pBlock) NN_NOEXCEPT
        {
            NN_SDK_REQUIRES_NOT_NULL(pBlock);

            if(this->m_pHead)
            {
                this->m_pTail->pNext = pBlock;
            }
            else
            {
                this->m_pHead = pBlock;
            }

            this->m_pTail = pBlock;
        }

        Block* Dequeue() NN_NOEXCEPT
        {
            Block *pBlock = this->m_pHead;

            if (!pBlock)
            {
                return nullptr;
            }

            this->m_pHead = this->m_pHead->pNext;

            pBlock->pNext = nullptr;

            if (pBlock == this->m_pTail)
            {
                this->m_pTail = nullptr;
            }

            return pBlock;
        }

        bool Remove(Block* pBlock) NN_NOEXCEPT
        {
            NN_SDK_REQUIRES_NOT_NULL(pBlock);

            Block* pPrev = nullptr;
            Block* pCurr = this->m_pHead;

            while (pCurr)
            {
                if (pCurr == pBlock)
                {
                    if (pCurr == this->m_pHead)
                    {
                        this->m_pHead = pCurr->pNext;

                        if (pCurr == this->m_pTail)
                        {
                            this->m_pTail = nullptr;
                        }
                    }
                    else if (pCurr == this->m_pTail)
                    {
                        this->m_pTail = pPrev;
                        this->m_pTail->pNext = nullptr;
                    }
                    else
                    {
                        pPrev->pNext = pCurr->pNext;
                    }

                    pCurr->pNext = nullptr;

                    return true;
                }

                pPrev = pCurr;
                pCurr = pCurr->pNext;
            }

            return false;
        }
    };

    Alignment::AlignedStorage<
        AtomicBlockByteCount * AtomicBlockCount, AlignmentByteCount> m_Storage;

    BlockList m_BlockLists[BlockListCountMax];

    Block* m_pAtomicBlocks;

    int m_BlockOrderMax;

    int GetIndexOfAtomicBlock(const Block& block) const NN_NOEXCEPT
    {
        auto index = static_cast<int>(&block - this->m_pAtomicBlocks);

        NN_SDK_ASSERT_RANGE(index, 0, AtomicBlockCount);

        return index;
    }

    void* GetAddressOfBlock(const Block& block) const NN_NOEXCEPT
    {
        const int offset =
            this->GetIndexOfAtomicBlock(block) * AtomicBlockByteCount;

        return reinterpret_cast<void*>(
            reinterpret_cast<uintptr_t>(this->m_Storage.data) + offset);
    }

    Block *GetBlockOfAddress(const void *p) const NN_NOEXCEPT
    {
        uintptr_t addr = reinterpret_cast<uintptr_t>(p);

        addr -= reinterpret_cast<uintptr_t>(this->m_Storage.data);

        int index = static_cast<int>(addr / AtomicBlockByteCount);

        NN_SDK_ASSERT_RANGE(index, 0, AtomicBlockCount);

        return &(this->m_pAtomicBlocks[index]);
    }

    bool IsAlignedToBlockOrder(
        const Block& block, int order) const NN_NOEXCEPT
    {
        NN_SDK_REQUIRES_RANGE(order, 0, this->m_BlockOrderMax + 1);

        const int count = GetBlockCountOfBlockOrder(order);

        return ((this->GetIndexOfAtomicBlock(block) % count) == 0);
    }

    void DivideBlock(Block *pBlock, int minOrder, int maxOrder) NN_NOEXCEPT
    {
        NN_SDK_REQUIRES_NOT_NULL(pBlock);
        NN_SDK_REQUIRES_RANGE(minOrder, 0, this->m_BlockOrderMax + 1);
        NN_SDK_REQUIRES_RANGE(
            maxOrder, minOrder, this->m_BlockOrderMax + 1);

        for (int i = maxOrder; minOrder < i; --i)
        {
            Block* pBuddyBlock = &pBlock[GetBlockCountOfBlockOrder(i - 1)];

            this->m_BlockLists[i - 1].Enqueue(pBuddyBlock);
        }
    }

    void UniteBlock(Block *pBlock, int order) NN_NOEXCEPT
    {
        NN_SDK_REQUIRES_NOT_NULL(pBlock);
        NN_SDK_REQUIRES_RANGE(order, 0, this->m_BlockOrderMax + 1);

        while (order < this->m_BlockOrderMax - 1)
        {
            const bool isLhs =
                this->IsAlignedToBlockOrder(*pBlock, order + 1);

            const int offset = GetBlockCountOfBlockOrder(order);

            Block* pBuddyBlock = pBlock + (isLhs ? offset : -offset);

            if (!this->m_BlockLists[order].Remove(pBuddyBlock))
            {
                break;
            }
            else
            {
                if (!isLhs)
                {
                    pBlock = pBuddyBlock;
                }

                ++order;
            }
        }

        this->m_BlockLists[order].Enqueue(pBlock);
    }

    Block* GetFreeBlock(int order) NN_NOEXCEPT
    {
        NN_SDK_REQUIRES_RANGE(order, 0, this->m_BlockOrderMax + 1);

        for(int i = order; i < BlockListCountMax; ++i)
        {
            if(!(this->m_BlockLists[i].IsEmpty()))
            {
                Block* pBlock = this->m_BlockLists[i].Dequeue();

                NN_SDK_ASSERT_NOT_NULL(pBlock);

                this->DivideBlock(pBlock, order, i);

                return pBlock;
            }
        }

        return nullptr;
    }

    void* AllocateByBlockOrder(int order) NN_NOEXCEPT
    {
        NN_SDK_REQUIRES_RANGE(order, 0, this->m_BlockOrderMax + 1);

        Block* pBlock = this->GetFreeBlock(order);

        if (pBlock)
        {
            return this->GetAddressOfBlock(*pBlock);
        }
        else
        {
            return nullptr;
        }
    }

    void Free(void* p, int order) NN_NOEXCEPT
    {
        NN_SDK_REQUIRES_NOT_NULL(p);
        NN_SDK_REQUIRES_RANGE(order, 0, this->m_BlockOrderMax + 1);

        Block* pBlock = GetBlockOfAddress(p);

        NN_SDK_ASSERT(this->IsAlignedToBlockOrder(*pBlock, order));

        this->UniteBlock(pBlock, order);
    }

public:
    BuddyMemoryAllocator() NN_NOEXCEPT
    {
        this->m_pAtomicBlocks =
            reinterpret_cast<Block*>(this->m_Storage.data);

        for (int i = 0; i < AtomicBlockCount; i++)
        {
            this->m_pAtomicBlocks[i].pNext = nullptr;
        }

        for (auto blockList : this->m_BlockLists)
        {
            blockList.Initialize();
        }

        this->m_BlockOrderMax = GetBlockOrder(AtomicBlockCount);

        NN_SDK_ASSERT_RANGE(this->m_BlockOrderMax, 0, BlockListCountMax);

        const int count = GetBlockCountOfBlockOrder(this->m_BlockOrderMax);
        if (AtomicBlockCount < count)
        {
            --(this->m_BlockOrderMax);
        }

        this->m_BlockLists[this->m_BlockOrderMax].Enqueue(
            this->m_pAtomicBlocks);

        void *p =
            this->AllocateByBlockOrder(
                GetBlockOrder(
                    GetBlockCountOfSize(sizeof(Block) * AtomicBlockCount)));

        NN_SDK_ASSERT_EQUAL(this->m_pAtomicBlocks, p);
        NN_UNUSED(p);
    }

    void* Allocate(size_t size) NN_NOEXCEPT
    {
        if (size == 0)
        {
            return nullptr;
        }

        int order =
            GetBlockOrder(GetBlockCountOfSize(size + AlignmentByteCount));

        void *p = this->AllocateByBlockOrder(order);

        if (!p)
        {
            return nullptr;
        }

        *static_cast<int*>(p) = order;

        return reinterpret_cast<void*>(
            reinterpret_cast<uintptr_t>(p) + AlignmentByteCount);
    }

    void Free(void* p) NN_NOEXCEPT
    {
        if (!p)
        {
            return;
        }

        p = reinterpret_cast<void*>(
            reinterpret_cast<uintptr_t>(p) - AlignmentByteCount);

        this->Free(p, *static_cast<int*>(p));
    }

private:
    NN_DISALLOW_COPY(BuddyMemoryAllocator);
    NN_DISALLOW_MOVE(BuddyMemoryAllocator);
};

// Namespace-scope definition of the static constant that is initialized
// inside the class body, so that it can also be odr-used.
template<
    size_t AtomicBlockByteCount,
    int AtomicBlockCount,
    size_t AlignmentByteCount>
const int
BuddyMemoryAllocator<AtomicBlockByteCount, AtomicBlockCount, AlignmentByteCount>
    ::BlockListCountMax;

// Heap configuration: 64 * 1024 atomic blocks of 64 bytes each, with a
// per-allocation header (and resulting pointer alignment) of sizeof(int).
using MemoryAllocator = BuddyMemoryAllocator<64, 64 * 1024, sizeof(int)>;

// Returns the single allocator instance of this translation unit. The
// instance is a function-local static, so it is constructed on first use.
MemoryAllocator& GetMemoryAllocator() NN_NOEXCEPT
{
    static MemoryAllocator s_Instance;
    return s_Instance;
}

// Scope guard for the allocator mutex: the constructor locks the mutex
// returned by GetMutex() and the destructor unlocks it again.
class AllocatorLocker final
{
public:
    AllocatorLocker() NN_NOEXCEPT
    {
        GetMutex().Lock();
    }

    ~AllocatorLocker() NN_NOEXCEPT
    {
        GetMutex().Unlock();
    }

private:
    // Mutex shared by all AllocatorLocker objects.
    ::nn::os::SdkMutexType& GetMutex() NN_NOEXCEPT;

    NN_DISALLOW_COPY(AllocatorLocker);
    NN_DISALLOW_MOVE(AllocatorLocker);
};

// Returns the one mutex used by every AllocatorLocker. It is a
// function-local static set up with the SDK's mutex initializer macro.
::nn::os::SdkMutexType& AllocatorLocker::GetMutex() NN_NOEXCEPT
{
    static ::nn::os::SdkMutexType s_AllocatorMutex =
        NN_OS_SDK_MUTEX_INITIALIZER();

    return s_AllocatorMutex;
}

} // namespace

// Allocates "size" bytes from the shared allocator while holding the
// allocator mutex. The result is whatever the allocator's Allocate() returns.
void *Heap::Allocate(size_t size) NN_NOEXCEPT
{
    AllocatorLocker lock;

    MemoryAllocator& allocator = GetMemoryAllocator();
    return allocator.Allocate(size);
}

// Hands p back to the shared allocator while holding the allocator mutex.
void Heap::Free(void *p) NN_NOEXCEPT
{
    AllocatorLocker lock;

    MemoryAllocator& allocator = GetMemoryAllocator();
    allocator.Free(p);
}

}}} // namespace nnt::testing::detail
