﻿#pragma once

#include <stdlib.h>
#include <stdint.h>

#ifdef _WIN32
#include <Windows.h>
#endif

// Shorthand for the platform's plain unsigned integer.
typedef unsigned int uint_t;

// Byte-count helpers: KiB(n) yields n * 1024 and MiB(n) yields n * 1024 * 1024.
// The 1024ULL factor makes the multiplication happen in unsigned long long for integer arguments.
#define KiB(size) ((size) * 1024ULL)
#define MiB(size) (KiB(size) * 1024ULL)

// Selects between the counting and the empty variant of Allocator::reference_counter_t.
#define MEMORY_STATUS_ENABLED           1

//
// Per-platform address-space configuration.
//
//   MEM_PLATFORM_TOKEN       defined for every supported platform
//   MEM_ADDRESS_BIT_COUNT    address width the allocator is built for (32 or 64)
//   VIRTUAL_PAGE_SIZE        page granularity, in bytes
//   VIRTUAL_PAGE_INDEX(a)    a divided by the page size (right shift)
//   VIRTUAL_PAGE_MASK(a)     a with the in-page offset bits cleared
//
#if defined(_WIN32) && defined(_M_IX86)

// 32-bit Windows: 4 KiB pages.
#define MEM_PLATFORM_TOKEN              _WIN32
#define MEM_ADDRESS_BIT_COUNT           32
#define VIRTUAL_PAGE_SIZE               (KiB(4))
#define VIRTUAL_PAGE_INDEX(Addr)        ((Addr)>>12)
#define VIRTUAL_PAGE_MASK(Addr)         ((Addr)&0xFFFFF000)

#elif defined(_WIN32) && defined(_M_X64)

// 64-bit Windows: 64 KiB pages.
#define MEM_PLATFORM_TOKEN              _WIN32
#define MEM_ADDRESS_BIT_COUNT           64
#define VIRTUAL_PAGE_SIZE               (KiB(64))
#define VIRTUAL_PAGE_INDEX(Addr)        ((Addr)>>16)
#define VIRTUAL_PAGE_MASK(Addr)         ((Addr)&0xFFFFFFFFFFFF0000)

#elif defined(_WIN32)

// Any other Windows architecture is rejected at compile time.
#error Unsupported platform.

#else

// Everything that is not Windows: treated as a 32-bit target with 4 KiB pages.
#define MEM_PLATFORM_TOKEN              1
#define MEM_ADDRESS_BIT_COUNT           32
#define VIRTUAL_PAGE_SIZE               (KiB(4))
#define VIRTUAL_PAGE_INDEX(Addr)        ((Addr)>>12)
#define VIRTUAL_PAGE_MASK(Addr)         ((Addr)&0xFFFFF000)
// Traps through the compiler builtin when the expression evaluates to false.
#define ASSERT(...) { if( !(__VA_ARGS__) ) __builtin_trap(); }

#endif


// Number of free-list bins (Allocator::kVirtualBinCount) for the selected address width.
#if MEM_ADDRESS_BIT_COUNT == 32
#define VIRTUAL_BIN_COUNT (1536)
#elif MEM_ADDRESS_BIT_COUNT == 64
#define VIRTUAL_BIN_COUNT (3647)
#else
#error Undefined bit count.
#endif


// Safety net: every branch above must have defined the platform token.
#if !defined(MEM_PLATFORM_TOKEN)
#error Platform not defined.
#endif

// Heap validation follows the build flavour: enabled for _DEBUG builds, disabled otherwise.
#if defined(_DEBUG)
#define MEM_VALIDATION_ON 1
#else
#define MEM_VALIDATION_ON 0
#endif

// __memory_assert() will stop the application like a debug_break.
// Declared ahead of the macros that expand to a call to it.
bool    __memory_assert     (void);

#if MEM_VALIDATION_ON
// MEM_ASSERT(cond): calls __memory_assert() when cond evaluates to false.
#define MEM_ASSERT(cond) ((cond) ? (void)0 : (void)__memory_assert())
// MEM_ASSERT_ONCE(cond): evaluates and checks cond only the first time this call site is
// reached; later passes through the same call site skip the check, whatever cond would yield.
#define MEM_ASSERT_ONCE(cond) \
    do { \
        static bool mem_assert_once_armed = true; \
        if (mem_assert_once_armed) { \
            mem_assert_once_armed = false; \
            MEM_ASSERT((cond)); \
        } \
    } while(0)
#else
// Validation disabled: both macros expand to nothing and cond is never evaluated.
#define MEM_ASSERT(cond)
#define MEM_ASSERT_ONCE(cond)
#endif


// Platform hooks that provide basic memory mapping.  They are only declared here; each
// platform supplies the definitions.
// NOTE(review): the per-function summaries below are inferred from the names and
// signatures - confirm against the platform implementation.

// Presumably reserves Size bytes of address space and returns the start address.
void*    SystemVirtualReserve        (size_t Size);
// Presumably backs [pAddress, pAddress + Size) with memory and returns the committed address.
void*    SystemVirtualCommit         (void* pAddress, size_t Size);
// Presumably gives the backing memory of the range back while keeping the reservation.
void     SystemVirtualDecommit       (void* pAddress, size_t Size);
// Presumably releases the range altogether.
void     SystemVirtualFree           (void* pAddress, size_t Size);


// Mutex used by the allocator.  A single definition serves every platform; only the
// native lock storage and the alignment requirement are platform specific.
//
//  - Windows: holds a CRITICAL_SECTION (the original note describes the implementation
//    as a spin-lock critical section).
//  - Elsewhere: no native lock object is stored at the moment (the member is commented
//    out) and the type is aligned to 16 bytes.
class MemMutex
{
public:
    MemMutex();
    ~MemMutex();

    void    Lock();
    void    Unlock();
    bool    Locked();

#ifdef _WIN32
    CRITICAL_SECTION Crit;
#else
    //SceKernelLwMutexWork LwMutex __attribute__((aligned(8)));
#endif
    int     LockedCount;
}
#ifndef _WIN32
__attribute__((aligned(16)))
#endif
;

// Scoped lock for a MemMutex: the constructor calls Lock() and the destructor calls
// Unlock(), so the mutex is released on every path that leaves the enclosing scope.
//
// The guard refers to the mutex by reference; the mutex must outlive the guard.
struct LockGuard
{
    // explicit: a MemMutex must never silently convert into a guard (which would lock and
    // immediately unlock it through a temporary).
    explicit LockGuard(MemMutex& Mutex) : mMutex(Mutex) {
        mMutex.Lock();
    }
    ~LockGuard() {
        mMutex.Unlock();
    }
    MemMutex& mMutex;

private:
    // Not copyable: a copy would call Unlock() a second time for a single Lock().
    // Declared but intentionally never defined.
    LockGuard(const LockGuard&);
    LockGuard& operator=(const LockGuard&);
};



// The allocator

namespace px
{
namespace Memory
{
    // Allocator built on top of the SystemVirtual* memory-mapping hooks.
    //
    // This header declares the public interface, the bookkeeping structures and the data
    // members.  Apart from a handful of inline helpers defined after the class, the member
    // function bodies are not part of this header.
    class Allocator
    {
    public:

        // Compile-time tuning constants.
        enum {
            kVirtualPageSize        = VIRTUAL_PAGE_SIZE,    // What is the minimum page size supported by this allocator, regardless of platform or OS specific requirements.
            kVirtualSmallBlockSize  = 512,                  // Small block threshold; anything larger will NOT go into the small block allocator.
            kVirtualSmallBinCount   = (kVirtualSmallBlockSize >> 3) + 1,    // Length of mVirtualSmallBins (65 with the 512-byte threshold).
            kVirtualBinFactor       = 256,                  // VirtualGetBinIndex() halves a size until it is no larger than this value.
            kVirtualBinCount        = VIRTUAL_BIN_COUNT,    // Length of mVirtualFreeBins.
        };


        // Public interface.
        // NOTE(review): the bodies are defined elsewhere; the exact contracts (return values
        // on failure, thread-safety, meaning of pName) must be confirmed against the implementation.
        void    Shutdown();
        // VirtualHashTableSize defaults to 10657.
        bool    Initialize(uintptr_t VirtualAddressRange, int VirtualHashTableSize = 10657);
        void*   VirtualAlloc(uintptr_t Size, uint_t Alignment, const char* pName);
        void*   VirtualRealloc(void* MemoryPtr, uintptr_t NewSize, uint_t Alignment, const char* pName);
        void*   VirtualResize(void* MemoryPtr, uintptr_t NewSize, uint_t Alignment, const char* pName);
        bool    VirtualFree(void* MemoryPtr);
        bool    VirtualValidateHeap();
        void*   VirtualGetBaseAddress();
        uintptr_t VirtualGetAddressRange();

        // pInternalAllocator is optional (defaults to NULL); it is kept in mpInternalAllocator.
        // NOTE(review): presumably the allocator used for this allocator's own bookkeeping - confirm.
        Allocator(Allocator* pInternalAllocator = NULL);
        ~Allocator();

        //
        // reference_counter_t
        //
        //  A counter that also remembers the highest value it has reached.  With
        //  MEMORY_STATUS_ENABLED set to 0 the type has no data members and every
        //  operation is an empty function.
        //
        struct reference_counter_t
        {
#if MEMORY_STATUS_ENABLED
            reference_counter_t()                   : Counter(0), Peak(0) {}
            // Clears both the current value and the peak.
            __inline void Reset()                   { Counter = Peak = 0; }
            // The operations that can raise the counter (Increment, Set, Add) also raise Peak when it is exceeded.
            __inline void Increment()               { if (++Counter > Peak) Peak = Counter; }
            // The lowering operations (Decrement, Subtract) never touch Peak and do not guard against going below zero.
            __inline void Decrement()               { --Counter; }
            __inline void Set(uintptr_t Value)      { Counter = Value; if (Counter > Peak) Peak = Counter; }
            __inline void Add(uintptr_t Count)      { Counter+= Count; if (Counter > Peak) Peak = Counter; }
            __inline void Subtract(uintptr_t Count) { Counter-= Count; }
            // Sets this counter to a.Counter - b.Counter.  The subtraction is unsigned, so it wraps around when b is the larger one.
            __inline void Diff(reference_counter_t &a, reference_counter_t &b) { Set(a.Counter - b.Counter); }
            // Restarts peak tracking from the current value.
            __inline void ResetPeak()               { Peak = Counter; }
            uintptr_t Counter;  // Current value.
            uintptr_t Peak;     // Highest value observed since construction, Reset() or ResetPeak().
#else
            // Statistics compiled out: same interface, no state, no work.
            reference_counter_t()                   {}
            __inline void Reset()                   {}
            __inline void Increment()               {}
            __inline void Decrement()               {}
            __inline void Set(uintptr_t Value)      {}
            __inline void Add(uintptr_t Count)      {}
            __inline void Subtract(uintptr_t Count) {}
            __inline void Diff(reference_counter_t &a, reference_counter_t &b) {}
            __inline void ResetPeak()               {}
#endif
        };


        //
        // Memory Status / Performance Counters
        //
        //  One reference_counter_t per tracked quantity.  The "Comitted" spelling is part of
        //  the public field names and is therefore left untouched.
        //  NOTE(review): the precise meaning and unit of each counter is defined by the code
        //  that updates it, which is not part of this header.
        //

        struct MemoryStatus
        {
            reference_counter_t  VirtualBytesReserved;
            reference_counter_t  VirtualBytesComitted;
            reference_counter_t  VirtualBytesAllocated;
            reference_counter_t  VirtualBytesUnused;
            reference_counter_t  VirtualBlocksReserved;
            reference_counter_t  VirtualBlocksAllocated;
            reference_counter_t  VirtualBlocksUnused;
            reference_counter_t  VirtualBlockFragments;
            reference_counter_t  VirtualPagesReserved;
            reference_counter_t  VirtualPagesComitted;
            reference_counter_t  VirtualPagesUnused;
            reference_counter_t  VirtualLargestExtent;
            reference_counter_t  VirtualSmallBlockPagesAllocated;
            reference_counter_t  VirtualSmallBlockBlocksReserved;
            reference_counter_t  VirtualSmallBlockBlocksAllocated;
            reference_counter_t  VirtualSmallBlockBlocksUnused;
            reference_counter_t  VirtualSmallBlockBytesReserved;
            reference_counter_t  VirtualSmallBlockBytesAllocated;
            reference_counter_t  VirtualSmallBlockBytesUnused;
            reference_counter_t  SysMemBytesAllocated;
        };

        // NOTE(review): presumably writes the current counters into st - confirm.
        void GetMemoryStatus(MemoryStatus& st);

    private:
        // Data structures

        // NOTE(review): presumably a bit value for VirtualBlock::Flags - confirm.
        enum {
            BLOCK_FREE = 1,
        };

        //
        // VirtualBlock
        //
        //  This block structure describes allocated memory.  Every allocation is associated
        //  with exactly one virtual block.  Small-block allocations are excluded.
        //

        struct VirtualBlock
        {
            // The two links share storage; only one of them is meaningful at any time.
            union {
                VirtualBlock* pNextFree;
                VirtualBlock* pGlobalNext;
            };

            VirtualBlock* pGlobalPrev;
            VirtualBlock* pBinNext;
            VirtualBlock* pBinPrev;
            uintptr_t MemoryPtr;
        private: // This enforces the use of GetMemorySize() and SetMemorySize() API
            // The size is stored divided by four so that it shares one address-sized word with the 2-bit Flags field.
            uintptr_t MemorySize:MEM_ADDRESS_BIT_COUNT-2;
        public:
            uintptr_t Flags:2;

            // Returns the stored size in bytes (always a multiple of four).
            uintptr_t GetMemorySize()                  { return MemorySize << 2; }
            // Stores Val divided by four; the two low bits of Val are discarded.
            void SetMemorySize(uintptr_t Val)          {        MemorySize = Val>>2; }
        };

        // VirtualBin
        //  Head of the list of free blocks belonging to one bin.
        struct VirtualBin
        {
            VirtualBlock* pFreeList;
        };

        // VirtualPage
        //  Per-page record; mpVirtualPageArray points at an array of these.
        //  NOTE(review): presumably describes the small blocks carved out of the page - confirm.
        struct VirtualPage
        {
            unsigned short BlockSize;
            unsigned short BlockCount;
        };

        // SmallBlock
        //  Doubly linked list node.
        struct SmallBlock
        {
            SmallBlock* pPrev;
            SmallBlock* pNext;
        };

        // SmallBlockBin
        //  One small-block bin: a free list together with its own mutex.
        struct SmallBlockBin
        {
            MemMutex mMutex;
            SmallBlock* pFreeList;
        };

        // Internal operations.  VirtualGetBinIndex, VirtualGetPageIndex, VirtualGetSmallBinIndex
        // and VirtualGetHash are defined inline after the class; the others are defined elsewhere.
        void            Construct();
        uint32_t        VirtualGetBinIndex(uintptr_t Size);
        uintptr_t       VirtualGetPageIndex(uintptr_t MemoryPtr);
        unsigned int    VirtualGetSmallBinIndex(int Size);
        void*           VirtualAllocSmallBlock(uint_t Size, uint_t Alignment);
        void*           VirtualAllocInternal(uintptr_t Size, uint_t Alignment);
        bool            VirtualFreeSmallBlock(void* MemoryPtr);
        void            VirtualFreeInternal(void* MemoryPtr);
        VirtualBlock*   VirtualCreateBlock(uintptr_t BlockPtr, uintptr_t BlockSize, uint_t Flags, VirtualBlock* pFriendBlock = 0);
        void            VirtualDestroyBlock(VirtualBlock* pBlock);
        void            VirtualInsertFreeBlock(VirtualBlock* pBlock, int BinIndex);
        void            VirtualRemoveFreeBlock(VirtualBlock* pBlock, int BinIndex);
        int             VirtualGetHash(uintptr_t BlockPtr);
        void            VirtualHashInsert(VirtualBlock* pBlock);
        VirtualBlock*   VirtualHashRemove(uintptr_t BlockPtr);
        VirtualBlock*   VirtualHashLookup(uintptr_t BlockPtr);
        void            VirtualMapPages(uintptr_t BlockPtr, uintptr_t BlockSize);
        void            VirtualUnmapPages(uintptr_t BlockPtr, uintptr_t BlockSize);
        void            Dump(const char* msg);

        //
        // The internals
        //

        VirtualBin          mVirtualFreeBins[kVirtualBinCount];         // One free list per bin.
        SmallBlockBin       mVirtualSmallBins[kVirtualSmallBinCount];   // One bin (with mutex) per small-block bin.
        // Fixed-size storage embedded in the allocator object.
        // NOTE(review): presumably used until dynamically sized bookkeeping is available - confirm.
        VirtualBlock        mVirtualInitialBlockPool[16];
        VirtualBlock*       mVirtualInitialHashTable[7];
        MemMutex            mVirtualMutex;
        uintptr_t           mVirtualBaseAddress;        // Subtracted from addresses by VirtualGetPageIndex() and VirtualGetHash().
        uintptr_t           mVirtualMemorySize;
        VirtualPage*        mpVirtualPageArray;
        uintptr_t           mVirtualPageCount;
        VirtualBlock*       mpVirtualBlockList;
        VirtualBlock**      mVirtualHashTable;
        uint_t              mVirtualHashTableSize;      // Modulus used by VirtualGetHash().
        Allocator*          mpInternalAllocator;

        MemoryStatus        mMemoryStatus;

        //
        // Free block cluster pools
        //

        // Singly linked chain of clusters.
        // NOTE(review): Blocks is declared with one element; presumably each cluster is
        // allocated larger than sizeof(VirtualBlockCluster) and holds more blocks - confirm.
        struct VirtualBlockCluster
        {
            VirtualBlockCluster* pNext;
            VirtualBlock Blocks[1];
        };

        VirtualBlockCluster*    mpVirtualBlockClusterChain;
        VirtualBlock*           mpVirtualFreeBlockPool;
        uint_t                  mVirtualFreeBlockCount;
    };



    //
    // force inline speed-critical functions
    //

    // Mixes all 64 bits of key through a fixed sequence of shift, add and xor steps and
    // returns the mixed value.
    //
    // The mix is carried out on an unsigned copy of the key:
    //  - left shifts and additions wrap modulo 2^64 instead of overflowing a signed
    //    integer (undefined behaviour);
    //  - right shifts are logical.  With an arithmetic shift every "key ^ (key >> n)" step
    //    clears the top bit and folds the sign into the result, which throws information
    //    away (for example, a key of 0 used to hash to 0).
    // Every step is invertible on 64-bit values, so distinct keys give distinct results.
    //
    // NOTE(review): the step sequence appears to be Thomas Wang's 64-bit mix function,
    // which is specified with logical right shifts.
    //
    // TODO: should we use this for the hash table?
    __inline int64_t Hash64(int64_t key)
    {
        uint64_t k = static_cast<uint64_t>(key);
        k = (~k) + (k << 21);           // k = (k << 21) - k - 1;
        k ^= k >> 24;
        k = (k + (k << 3)) + (k << 8);  // k * 265
        k ^= k >> 14;
        k = (k + (k << 2)) + (k << 4);  // k * 21
        k ^= k >> 28;
        k += k << 31;
        return static_cast<int64_t>(k);
    }

    // Maps a block size to the index of its free-list bin.
    //
    // The size is halved until it is no larger than kVirtualBinFactor; every halving moves
    // the index on by kVirtualBinFactor / 4 bins.  The remaining size, divided by four,
    // selects the bin inside that group.
    __inline uint32_t Allocator::VirtualGetBinIndex(uintptr_t Size)
    {
        const uint32_t BinsPerHalving = kVirtualBinFactor >> 2;

        uint32_t GroupBase = 0;
        for (; Size > kVirtualBinFactor; Size >>= 1)
            GroupBase += BinsPerHalving;

        return GroupBase + (uint32_t)(Size >> 2);
    }

    // Returns the index of the page containing MemoryPtr, counted from mVirtualBaseAddress.
    __inline uintptr_t Allocator::VirtualGetPageIndex(uintptr_t MemoryPtr)
    {
        const uintptr_t OffsetFromBase = MemoryPtr - mVirtualBaseAddress;
        return VIRTUAL_PAGE_INDEX(OffsetFromBase);
    }

    // Maps a small-block size to its bin index.
    //
    //   Size <= 16 : 8-byte steps  -> bins 0, 1, 2 (block sizes 0, 8, 16)
    //   Size  > 16 : 16-byte steps -> bin 3 for sizes 17..32, bin 4 for 33..48, and so on
    __inline unsigned int Allocator::VirtualGetSmallBinIndex(int Size)
    {
        if (Size > 16)
            return 1 + ((Size + 15) >> 4);

        return (Size + 7) >> 3;
    }

    // Returns the hash-table slot for a block address: the offset from mVirtualBaseAddress
    // in units of 256 bytes, reduced modulo mVirtualHashTableSize.
    __inline int Allocator::VirtualGetHash(uintptr_t BlockPtr)
    {
        const uintptr_t OffsetFromBase = BlockPtr - mVirtualBaseAddress;
        return static_cast<int>((OffsetFromBase >> 8) % mVirtualHashTableSize);
    }
};

};
