﻿/*--------------------------------------------------------------------------------*
  Copyright (C)Nintendo All rights reserved.

  These coded instructions, statements, and computer programs contain proprietary
  information of Nintendo and/or its licensed developers and are protected by
  national and international copyright laws. They may not be disclosed to third
  parties or copied or duplicated in any form, in whole or in part, without the
  prior written consent of Nintendo.

  The content herein is highly confidential and should be handled accordingly.
 *--------------------------------------------------------------------------------*/

#pragma once

#include <nn/nn_Common.h>
#include "kern_Assert.h"
#include "kern_KTaggedAddress.h"

// Extra page-heap consistency checks are compiled in for develop and debug SDK builds only.
#if defined(NN_SDK_BUILD_DEVELOP) || defined(NN_SDK_BUILD_DEBUG)
#define NN_KERN_HEAP_DEBUG
#endif

// NN_KERN_HEAP_ASSERT(condition): forwards to NN_KERN_ABORT_UNLESS when NN_KERN_HEAP_DEBUG is
// defined; otherwise it expands to nothing, so `condition` is not evaluated at all.
#if !defined(NN_KERN_HEAP_DEBUG)
#define NN_KERN_HEAP_ASSERT(condition)
#else
#define NN_KERN_HEAP_ASSERT(condition)  NN_KERN_ABORT_UNLESS(condition)
#endif

namespace nn { namespace kern {
    class KPageGroup;

    /*
     * Page heap that tracks free memory per block size.
     *
     * Each block-size class is a FreeBlocks instance managing blocks of (1 << blockShift) bytes;
     * at most MaxNumBlockPages classes exist. Within a class, blocks below the "back start"
     * address form the front region and blocks at or above it form the back region; the
     * fromBack flags select which region blocks are taken from.
     */
    class KPageHeap
    {
    public:
        enum
        {
            MaxNumBlockPages = 8,   // maximum number of block-size classes (size of m_FreeBlocks)
            PageSize = NN_KERN_FINEST_PAGE_SIZE
        };
        // Defined out of line. NOTE(review): presumably returns the managementSize that
        // Initialize() expects for these parameters — confirm against the implementation.
        static size_t CalcManagementAreaSize(size_t heapSize, const size_t blockPageShift[], size_t numBlockPages);

    private:
        /*
         * Set of free blocks of a single size, (1 << m_BlockShift) bytes each.
         *
         * Block i starts at m_HeapBase + (i << m_BlockShift) and is free while bit i of m_Bitmap
         * is set. Blocks with index < m_BackOffset form the front region; the rest, up to
         * m_EndOffset, form the back region.
         */
        class FreeBlocks
        {
        private:
            /*
             * Hierarchical bitmap built from 64-bit words, with at most four levels.
             *
             * Level m_Depth - 1 is the leaf level (one bit per block). Each bit of a higher level
             * is kept set exactly while the corresponding word of the level below is non-zero, so
             * SearchHigh()/SearchLow() reach a set leaf bit with one word read per level,
             * starting from the root word m_pBitArray[0][0].
             */
            class Bitmap
            {
            public:
                Bitmap() : m_NumBit(0), m_Depth(0)
                {
                    for (size_t depth = 0; depth < NN_ARRAY_SIZE(m_pBitArray); depth++)
                    {
                        m_pBitArray[depth] = nullptr;
#ifdef NN_KERN_HEAP_DEBUG
                        m_pBitArrayEnd[depth] = nullptr;
#endif
                    }
                }

                /*
                 * Returns the number of bytes of Bit64 storage that Initialize() consumes for a
                 * bitmap of `length` leaf bits (it computes the same level layout).
                 */
                static size_t CalcSize(uintptr_t length)
                {
                    // Number of levels: the smallest depth for which length < 64^depth.
                    uintptr_t high = length;
                    int numDepth = 0;
                    for (int depth = 0;; depth++)
                    {
                        high /= NN_BITSIZEOF(Bit64);
                        if (high == 0)
                        {
                            numDepth = depth + 1;
                            break;
                        }
                    }

                    // Each level needs ceil(bit count of the level below / 64) words.
                    high = length;
                    size_t num = 0;
                    for (int depth = numDepth - 1; depth >= 0; depth--)
                    {
                        high = (high + NN_BITSIZEOF(Bit64) - 1) / NN_BITSIZEOF(Bit64);
                        num += high;
                    }
                    return num * sizeof(Bit64);
                }

                /*
                 * Assigns each level its slice of pBitArray for a bitmap of `length` leaf bits and
                 * returns the first Bit64 past the storage used. Levels are laid out leaf level
                 * first, root level last.
                 *
                 * A complete tree would need (64^depth - 1) / (64 - 1) words, but the trailing
                 * words of each level are never accessed, so they are omitted.
                 *
                 * The storage is not cleared here. NOTE(review): presumably the caller passes
                 * zero-filled memory — confirm.
                 */
                Bit64* Initialize(uintptr_t length, Bit64* pBitArray)
                {
                    uintptr_t high = length;
                    for (int depth = 0;; depth++)
                    {
                        high /= NN_BITSIZEOF(Bit64);
                        if (high == 0)
                        {
                            m_Depth = depth + 1;
                            break;
                        }
                    }

                    NN_KERN_HEAP_ASSERT(static_cast<size_t>(m_Depth) <= NN_ARRAY_SIZE(m_pBitArray));
                    high = length;
                    for (int depth = m_Depth - 1; depth >= 0; depth--)
                    {
                        m_pBitArray[depth] = pBitArray;
                        high = (high + NN_BITSIZEOF(Bit64) - 1) / NN_BITSIZEOF(Bit64);
                        pBitArray += high;
#ifdef NN_KERN_HEAP_DEBUG
                        m_pBitArrayEnd[depth] = pBitArray;
#endif
                    }
                    return pBitArray;
                }
                // Sets leaf bit `offset` (debug builds assert it was clear) and counts it.
                void BitOn(uintptr_t offset)
                {
                    m_NumBit++;
                    BitOn(m_Depth - 1, offset);
                }
                // Clears leaf bit `offset` (debug builds assert it was set) and uncounts it.
                void BitOff(uintptr_t offset)
                {
                    m_NumBit--;
                    BitOff(m_Depth - 1, offset);
                }
                // Number of set leaf bits, maintained incrementally.
                size_t GetNumBits() const { return m_NumBit; }
                // Counts set leaf bits in [begin, end) by scanning the leaf level word by word.
                size_t GetNumBits(uintptr_t begin, uintptr_t end) const
                {
                    const Bit64* pBitArray = m_pBitArray[m_Depth - 1];
                    size_t count = 0;

                    while (begin < end)
                    {
                        uintptr_t n = (begin / NN_BITSIZEOF(Bit64));
                        Bit64 x = pBitArray[n];
                        // Drop bits below `begin` in this word.
                        if ((begin % NN_BITSIZEOF(Bit64)) != 0)
                        {
                            x &= ~((1ull << (begin % NN_BITSIZEOF(Bit64))) - 1);
                        }

                        // In the word containing `end`, drop bits at or above `end`.
                        if ((end / NN_BITSIZEOF(Bit64)) == n)
                        {
                            if ((end % NN_BITSIZEOF(Bit64)) != 0)
                            {
                                x &= ((1ull << (end % NN_BITSIZEOF(Bit64))) - 1);
                            }
                            else
                            {
                                // Not reachable while begin < end; kept defensively.
                                x = 0;
                            }
                        }
                        count += __builtin_popcountll(x);
                        // Advance to the first bit of the next word.
                        begin = (begin + NN_BITSIZEOF(Bit64)) & ~(NN_BITSIZEOF(Bit64) - 1);
                    }
                    return count;
                }
                // Returns the index of the highest set leaf bit, or -1 if no bit is set.
                intptr_t SearchHigh() const
                {
                    int depth = 0;
                    uintptr_t offset = 0;

                    do
                    {
                        Bit64 v = m_pBitArray[depth][offset];
                        if (v == 0)
                        {
                            // Only the root may be empty; a zero word below it would contradict
                            // the summary bit that led here.
                            NN_KERN_HEAP_ASSERT(depth == 0);
                            return -1;
                        }
                        offset = NN_BITSIZEOF(Bit64) * offset + 63 - __builtin_clzll(v);
                    } while (++depth < m_Depth);

                    return offset;
                }
                // Returns the index of the lowest set leaf bit, or -1 if no bit is set.
                intptr_t SearchLow() const
                {
                    int depth = 0;
                    uintptr_t offset = 0;

                    do
                    {
                        Bit64 v = m_pBitArray[depth][offset];
                        if (v == 0)
                        {
                            // Only the root may be empty; see SearchHigh().
                            NN_KERN_HEAP_ASSERT(depth == 0);
                            return -1;
                        }
                        offset = NN_BITSIZEOF(Bit64) * offset + __builtin_ctzll(v);
                    } while (++depth < m_Depth);

                    return offset;
                }

                /*
                 * Clears leaf bits [offset, offset + width) if all of them are currently set and
                 * returns true; otherwise leaves the bitmap unchanged and returns false.
                 *
                 * width must either fit inside the word containing offset, or be a multiple of 64
                 * with offset 64-aligned (asserted in debug builds only).
                 */
                bool ClearContiguous(uintptr_t offset, size_t width)
                {
                    int depth = m_Depth - 1;
                    Bit64* pBitArray = m_pBitArray[depth];
                    size_t index = (offset / NN_BITSIZEOF(Bit64));
                    if (NN_UNLIKELY(width >= NN_BITSIZEOF(Bit64)))
                    {
                        NN_KERN_HEAP_ASSERT((width % NN_BITSIZEOF(Bit64)) == 0);
                        NN_KERN_HEAP_ASSERT((offset % NN_BITSIZEOF(Bit64)) == 0);
                        // First verify every word is fully set, so nothing is modified on failure.
                        size_t left = width;
                        size_t i = 0;
                        do
                        {
                            if (pBitArray[index + i++] != ~0ull)
                            {
                                return false;
                            }
                            left -= NN_BITSIZEOF(Bit64);
                        } while (left > 0);

                        // Then clear the words; each now-zero word clears its bit one level up.
                        left = width;
                        i = 0;
                        do
                        {
                            pBitArray[index + i] = 0;
                            BitOff(depth - 1, index + i);
                            i++;
                            left -= NN_BITSIZEOF(Bit64);
                        } while (left > 0);
                    }
                    else
                    {
                        int shift = offset % NN_BITSIZEOF(Bit64);
                        NN_KERN_HEAP_ASSERT(shift + width <= NN_BITSIZEOF(Bit64));
                        Bit64  mask = (((1ull << width) - 1) << shift);
                        Bit64  x = pBitArray[index];
                        if ((x & mask) != mask)
                        {
                            return false;
                        }
                        x &= ~mask;
                        pBitArray[index] = x;
                        // The word became empty: clear its summary bit one level up.
                        if (x == 0)
                        {
                            BitOff(depth - 1, index);
                        }
                    }
                    m_NumBit -= width;
                    return true;
                }
            private:
                // Sets bit `offset` at level `depth`. If its word was zero beforehand, the word's
                // bit in the parent level is set as well, repeating up to the root (depth 0).
                // A negative depth is a no-op.
                void BitOn(int depth, uintptr_t offset)
                {
                    while (depth >= 0)
                    {
                        uintptr_t n = (offset / NN_BITSIZEOF(Bit64));
                        int bit = (offset % NN_BITSIZEOF(Bit64));
                        Bit64* pBit = &m_pBitArray[depth][n];
                        NN_KERN_HEAP_ASSERT(pBit < m_pBitArrayEnd[depth]);

                        Bit64 v = *pBit;
                        NN_KERN_HEAP_ASSERT((v & (1ull << bit)) == 0);
                        *pBit = v | (1ull << bit);
                        if (v != 0)
                        {
                            break;
                        }
                        offset = n;
                        depth--;
                    }
                }
                // Clears bit `offset` at level `depth`. If its word becomes zero, the word's bit
                // in the parent level is cleared as well, repeating up to the root (depth 0).
                // A negative depth is a no-op.
                void BitOff(int depth, uintptr_t offset)
                {
                    while (depth >= 0)
                    {
                        uintptr_t n = (offset / NN_BITSIZEOF(Bit64));
                        int bit = (offset % NN_BITSIZEOF(Bit64));
                        Bit64* pBit = &m_pBitArray[depth][n];
                        NN_KERN_HEAP_ASSERT(pBit < m_pBitArrayEnd[depth]);

                        Bit64 v = *pBit;
                        NN_KERN_HEAP_ASSERT((v & (1ull << bit)) != 0);
                        v &= ~(1ull << bit);
                        *pBit = v;
                        if (v != 0)
                        {
                            break;
                        }
                        offset = n;
                        depth--;
                    }
                }

            private:
                Bit64* m_pBitArray[4];      // word array of each level; index 0 is the root
#ifdef NN_KERN_HEAP_DEBUG
                Bit64* m_pBitArrayEnd[NN_ARRAY_SIZE(m_pBitArray)];  // one past each level's words, for bounds asserts
#endif
                size_t m_NumBit;            // number of set leaf bits
                int m_Depth;                // number of levels in use
            };

        public:
            /*
             * Removes one free block and returns its address, or Null if none is available.
             *
             * fromBack == true takes the highest-addressed free block of the back region
             * (index >= m_BackOffset); otherwise the lowest-addressed free block of the front
             * region (index < m_BackOffset).
             */
            KVirtualAddress Pop(bool fromBack)
            {
                intptr_t offset;
                if (fromBack)
                {
                    offset = m_Bitmap.SearchHigh();
                    if (offset < 0 || static_cast<uintptr_t>(offset) < m_BackOffset)
                    {
                        return Null<KVirtualAddress>();
                    }
                }
                else
                {
                    offset = m_Bitmap.SearchLow();
                    if (offset < 0 || m_BackOffset <= static_cast<uintptr_t>(offset))
                    {
                        return Null<KVirtualAddress>();
                    }
                }
                m_Bitmap.BitOff(offset);
                return m_HeapBase + (offset << m_BlockShift);
            }

            /*
             * Marks the block at addr as free.
             *
             * When a larger next size exists (m_NextBlockShift != 0) and every block of the
             * enclosing aligned next-size block is now free, those blocks are removed from this
             * set and the next-size block's address is returned (presumably for the caller to
             * free it at the next size). Otherwise returns Null.
             *
             * NOTE(review): addr is assumed to be block-aligned and inside this size class's
             * range; only debug builds bounds-check it.
             */
            KVirtualAddress Push(KVirtualAddress addr)
            {
                intptr_t offset = ((addr - m_HeapBase) >> m_BlockShift);
                m_Bitmap.BitOn(offset);
                if (m_NextBlockShift)
                {
                    // Number of blocks of this size making up one next-size block.
                    size_t width = (1 << (m_NextBlockShift - m_BlockShift));
                    // Index of the first block of the enclosing next-size block.
                    offset &= ~(width - 1);
                    if (m_Bitmap.ClearContiguous(offset, width))
                    {
                        return m_HeapBase + (offset << m_BlockShift);
                    }
                }
                return Null<KVirtualAddress>();
            }

            /*
             * Bytes of bitmap storage Initialize() needs for a region of `size` bytes. The two
             * extra alignment units cover Initialize() rounding the region's base down and its
             * end up to the same alignment.
             */
            static size_t CalcSize(size_t size, int blockShift, int nextBlockShift)
            {
                size_t align = ((nextBlockShift == 0)? (1ul << blockShift): (1ul << nextBlockShift));
                size = ((size + (align - 1)) & ~(align - 1));
                size += align * 2;
                return Bitmap::CalcSize(size / (1ul << blockShift));
            }
            /*
             * Sets this size class up for the region [base, base + size), using the bitmap
             * storage at pBitArray; returns the first Bit64 past the storage consumed.
             *
             * The region is widened to the next block size's alignment (or this block size's when
             * nextBlockShift == 0). This makes the aligned block groups that Push() coalesces
             * correspond to addresses aligned to the next block size. backStart is rounded down
             * to this block size.
             */
            Bit64* Initialize(KVirtualAddress base, size_t size, int blockShift, int nextBlockShift, KVirtualAddress backStart, Bit64* pBitArray)
            {
                KVirtualAddress end = base + size;
                m_BlockShift = blockShift;
                m_NextBlockShift = nextBlockShift;

                size_t align = ((nextBlockShift == 0)? (1ul << blockShift): (1ul << nextBlockShift));
                base = (GetAsInteger(base) & ~(align - 1));
                end = ((GetAsInteger(end) + (align - 1)) & ~(align - 1));
                backStart = (GetAsInteger(backStart) & ~((1ul << blockShift) - 1));
                m_HeapBase = base;

                m_BackOffset = (backStart - m_HeapBase) / (1ul << blockShift);
                m_EndOffset = (end - m_HeapBase) / (1ul << blockShift);
                return m_Bitmap.Initialize((end - base) / (1ul << blockShift), pBitArray);
            }

            int GetBlockShift()         const { return m_BlockShift; }
            int GetNextBlockShift()     const { return m_NextBlockShift; }
            size_t GetBlockPages()      const { return (1ul << m_BlockShift) / PageSize; }
            size_t GetBlockSize()       const { return (1ul << m_BlockShift); }
            // Free-block count taken from the bitmap's running counter.
            size_t GetNumFreeBlocks()   const { return m_Bitmap.GetNumBits(); }
            size_t GetNumFreePages()    const { return GetNumFreeBlocks() * GetBlockPages(); }
            // The Count*FreeBlocks() functions recount by scanning the leaf bitmap.
            size_t CountFreeBlocks() const
            {
                return m_Bitmap.GetNumBits(0, m_EndOffset);
            }
            size_t CountFrontFreeBlocks() const
            {
                return m_Bitmap.GetNumBits(0, m_BackOffset);
            }
            size_t CountBackFreeBlocks() const
            {
                return m_Bitmap.GetNumBits(m_BackOffset, m_EndOffset);
            }

            FreeBlocks():
                m_HeapBase(Null<KVirtualAddress>()),
                m_BackOffset(0),
                m_EndOffset(0),
                m_BlockShift(0),
                m_NextBlockShift(0)
            {}
        private:
            Bitmap m_Bitmap;                // bit i set while block i is free
            KVirtualAddress m_HeapBase;     // address of block 0 (aligned down in Initialize)
            uintptr_t m_BackOffset;         // first block index of the back region
            uintptr_t m_EndOffset;          // number of block indices covered by the bitmap
            int m_BlockShift;               // log2 of this block size in bytes
            int m_NextBlockShift;           // log2 of the next block size, or 0 if there is none
        };

    public:
        // Every member, including m_BackStart, starts out null/zero so no field is left
        // indeterminate before Initialize() runs.
        KPageHeap() :
            m_HeapStart(Null<KVirtualAddress>()),
            m_BackStart(Null<KVirtualAddress>()),
            m_HeapSize(0),
            m_NumBlockPages(0)
        {}
        // The following operations are defined out of line.
        void                Initialize(KVirtualAddress heapStart, size_t heapSize, KVirtualAddress heapBackStart, KVirtualAddress managementStart, size_t managementSize, const size_t blockPageShift[], size_t numBlockPages);
        Result              Allocate(KPageGroup* pOut, size_t numPages, bool fromBack);
        KVirtualAddress     AllocateContinuous(size_t numPages, size_t pageAlign, bool fromBack);
        void                Free(KVirtualAddress addr, size_t numPages);
        void                FreeImpl(KVirtualAddress addr, size_t numPages);
        size_t              GetNumFreePages() const;
        KVirtualAddress     GetHeapStartAddress() const { return m_HeapStart; }
        size_t              GetHeapSize() const { return m_HeapSize; }
        KVirtualAddress     GetHeapEndAddress() const { return m_HeapStart + m_HeapSize; }
        void                DumpFreeList();
        // True when a lies in [heap start, heap start + heap size).
        bool                Includes(KVirtualAddress a) const { return (m_HeapStart <= a) && (a < m_HeapStart + m_HeapSize); }

    private:
        NN_FORCEINLINE int AdaptiveBlock(size_t numPages) const;
        NN_FORCEINLINE int AlignedBlock(size_t numPages) const;
        NN_FORCEINLINE void FreeBlock(KVirtualAddress addr, int index);
        NN_FORCEINLINE KVirtualAddress AllocateBlock(int index, bool fromBack);

    private:
        KVirtualAddress     m_HeapStart;
        KVirtualAddress     m_BackStart;
        size_t              m_HeapSize;
        int                 m_NumBlockPages;    // presumably the number of m_FreeBlocks entries in use — confirm

        FreeBlocks          m_FreeBlocks[MaxNumBlockPages];
    };
}}

