﻿/*--------------------------------------------------------------------------------*
  Copyright (C)Nintendo All rights reserved.

  These coded instructions, statements, and computer programs contain proprietary
  information of Nintendo and/or its licensed developers and are protected by
  national and international copyright laws. They may not be disclosed to third
  parties or copied or duplicated in any form, in whole or in part, without the
  prior written consent of Nintendo.

  The content herein is highly confidential and should be handled accordingly.
 *--------------------------------------------------------------------------------*/

#pragma once

#include <cstring>
#include <map>
#include <nn/dd.h>

#include <nn/usb/usb_Result.h>

#include "usb_Util.h"


namespace nn  { namespace usb {  namespace detail {

/// A contiguous, currently unallocated range of the IoVa heap.
class IoVaFreeSegment
{
public:
    /// @param size  length of the free range in bytes.
    /// @param ioVa  device virtual address at which the free range starts.
    IoVaFreeSegment(uint32_t size, nn::dd::DeviceVirtualAddress ioVa)
        : m_Size(size)
        , m_IoVa(ioVa)
    {
    }

    /// Length of the free range in bytes.
    size_t GetSize() const {return m_Size;}

    /// Start address of the free range.
    nn::dd::DeviceVirtualAddress GetIoVa() const {return m_IoVa;}

private:
    size_t                       m_Size;
    nn::dd::DeviceVirtualAddress m_IoVa;
};

/// Bookkeeping record for one live IoVa allocation.
///
/// The reserved range starts at the base IoVa (GetBaseIoVa()) and spans GetSize()
/// bytes. The address handed to the caller is base + alignedOffset (GetIoVa()).
/// The remaining fields echo the caller's request so the allocation can be looked
/// up and reported later.
class IoVaAllocatedSegment
{
public:
    /// @param size           total reserved bytes measured from ioVa; as built by
    ///                       IoVaManager::Allocate this includes the alignment padding.
    /// @param alignedOffset  distance from ioVa to the aligned address returned to the caller.
    /// @param ioVa           base (unaligned) start address of the reserved range.
    /// @param requestedSize  number of bytes the caller asked for.
    /// @param procIndex      index of the procVa lookup table this allocation belongs to.
    /// @param procVa         process virtual address associated with the mapping.
    /// @param context        caller-supplied context value.
    /// @param tag            caller-supplied grouping tag.
    IoVaAllocatedSegment(uint32_t size, uint32_t alignedOffset, nn::dd::DeviceVirtualAddress ioVa,
                         uint32_t requestedSize, int32_t procIndex, uint64_t procVa,
                         uint64_t context, int32_t tag)
        : m_Size(size)
        , m_AlignedOffset(alignedOffset)
        , m_IoVa(ioVa)
        , m_RequestedSize(requestedSize)
        , m_ProcVaIndex(procIndex)
        , m_ProcVa(procVa)
        , m_Context(context)
        , m_Tag(tag)
    {
    }

    /// Total reserved bytes, measured from the base IoVa.
    size_t GetSize() const {return m_Size;}
    /// Bytes originally requested by the caller.
    size_t GetRequestedSize() const {return m_RequestedSize;}
    /// Unaligned start of the reserved range.
    nn::dd::DeviceVirtualAddress  GetBaseIoVa() const {return m_IoVa;}
    /// Aligned address handed back to the caller (base + aligned offset).
    nn::dd::DeviceVirtualAddress  GetIoVa() const {return m_IoVa + m_AlignedOffset;}
    /// Index of the procVa lookup table the allocation was recorded in.
    int32_t  GetProcVaIndex() const {return m_ProcVaIndex;}
    /// Process virtual address associated with the mapping.
    uint64_t GetProcVa() const {return m_ProcVa;}
    /// Caller-supplied context value.
    uint64_t GetContext() const {return m_Context;}
    /// Caller-supplied grouping tag.
    int32_t GetTag() const {return m_Tag;}

private:
    size_t     m_Size;
    uint32_t   m_AlignedOffset;
    nn::dd::DeviceVirtualAddress m_IoVa;
    size_t     m_RequestedSize;
    int32_t    m_ProcVaIndex;
    uint64_t   m_ProcVa;
    uint64_t   m_Context;
    int32_t    m_Tag;
};

/// Snapshot of one IoVa allocation, as reported by IoVaManager lookups and passed to
/// the IoVaManager::FreeByTag() callback.
/// NOTE: the fields have no default initializers; value-initialize
/// (`IoVaAllocationSummary s = {};`) if a producer might not write every field.
struct IoVaAllocationSummary
{
    size_t     requestedSize;   // bytes the caller originally requested
    int32_t    procVaIndex;     // index of the procVa lookup table the allocation belongs to
    uint64_t   procVa;          // process virtual address of the mapping; presumably 0 when none was given
    nn::dd::DeviceVirtualAddress ioVa;  // aligned IoVa that was handed back to the caller
    uint64_t   context;         // caller-supplied context value (part of the procVa lookup key)
    int32_t    tag;             // caller-supplied grouping tag
};

template<size_t MinimumSegmentSize, int32_t MaxProcVaLists>
class IoVaManager
{
public:
    IoVaManager() { }
    ~IoVaManager() { }
    void Initialize(nn::dd::DeviceVirtualAddress heapBase, uint32_t heapSize)
    {
        m_HeapBase = heapBase;
        m_HeapSize = heapSize;
        m_MapCount = 0;
        if(!m_FreeSizeList.empty()) NN_USB_ABORT("IoVaManager::Initialize() invalid state.\n");
        m_FreeSizeList.insert(std::make_pair(heapBase, IoVaFreeSegment(heapSize, heapBase)));
    }
    void Finalize()
    {
        m_FreeSizeList.clear();
        m_AllocatedIoVaList.clear();
        for(int32_t clientIndex=0; clientIndex < MaxProcVaLists; clientIndex++)
        {
            m_AllocatedProcVaList[clientIndex].clear();
        }
    }
    nn::dd::DeviceVirtualAddress Allocate(uint32_t requestedSize, uint32_t alignment, int32_t procVaIndex, uint64_t procVa,
                                          uint64_t context, int32_t tag)
    {
        nn::dd::DeviceVirtualAddress ioVa = 0ULL;
        uint32_t alignmentR = NN_USB_ROUNDUP_SIZE(alignment, MinimumSegmentSize);
        size_t sizeR = NN_USB_ROUNDUP_SIZE(requestedSize, alignmentR);

        if(!(procVaIndex < MaxProcVaLists))
        {
            NN_USB_LOG_ERROR("IoVaManager::Allocate() invalid procVaIndex=%d.\n",procVaIndex);
            return 0ULL;
        }

        // make sure it's not mapped yet
        if (GetIoVa(procVaIndex, context, procVa, &ioVa).IsSuccess())
        {
            NN_USB_ABORT("IoVaManager::Allocate() double map %p:%p\n", context, procVa);
            return 0ULL;
        }

        for (FreeSizeListType::iterator itr = m_FreeSizeList.lower_bound(sizeR); itr != m_FreeSizeList.end(); itr++)
        {
            IoVaFreeSegment allocSeg = itr->second;
            nn::dd::DeviceVirtualAddress ioVaAligned = NN_USB_ROUNDUP_SIZE(allocSeg.GetIoVa(), alignmentR);
            size_t offset = static_cast<uint32_t>(ioVaAligned - allocSeg.GetIoVa());
            size_t remainingSize = allocSeg.GetSize() - offset;
            if((allocSeg.GetSize() > offset) && (remainingSize >= sizeR))
            {
                size_t allocSegSize = allocSeg.GetSize();

                // remove from free list, it's now allocated
                m_FreeSizeList.erase(itr);

                // do we need to break up this segment?
                if((remainingSize - sizeR) > MinimumSegmentSize)
                {
                    // make new free segment out of surplus
                    size_t surplusSize = remainingSize - sizeR;
                    nn::dd::DeviceVirtualAddress surplusIoVa =  allocSeg.GetIoVa() + offset + sizeR;
                    m_FreeSizeList.insert(std::make_pair(surplusIoVa,
                                                         IoVaFreeSegment(surplusSize, surplusIoVa)));
                    allocSegSize = allocSegSize - surplusSize;
                }

                // create ioVa record of allocation
                if(m_AllocatedIoVaList.insert(std::make_pair(allocSeg.GetIoVa(),
                                                             IoVaAllocatedSegment(allocSegSize, offset, allocSeg.GetIoVa(),
                                                                                  requestedSize, procVaIndex, procVa,
                                                                                  context, tag))).second == false)
                {
                    NN_USB_ABORT("IoVaManager::Allocate() failed to insert into m_AllocatedIoVaList.\n");
                    return 0ULL;
                }

                // pass back allocated ioVa
                ioVa = allocSeg.GetIoVa() + offset;

                // create procVa record of allocation
                if(procVa != 0ULL)
                {
                    if(m_AllocatedProcVaList[procVaIndex].insert(std::make_pair(std::make_pair(context, procVa),
                                                                                IoVaAllocatedSegment(allocSegSize, offset, allocSeg.GetIoVa(),
                                                                                                     requestedSize, procVaIndex, procVa,
                                                                                                     context, tag))).second == false)
                    {
                        NN_USB_ABORT("IoVaManager::Allocate() failed to insert context=0x%llx procVa=0x%llx into m_AllocatedProcVaList.\n", context, procVa);
                        return 0ULL;
                    }
                }

                break;
            }
        }
        m_MapCount++;
        return ioVa;
    }
    Result Free(nn::dd::DeviceVirtualAddress ioVa)
    {
        if(m_AllocatedIoVaList.empty())
        {
            NN_USB_LOG_INFO("IoVaManager::Free(ioVa=%llx) failed to locate ioVa, empty list.\n", ioVa);
            return ResultIoVaError();
        }
        auto locatedIoVaMap = m_AllocatedIoVaList.find(ioVa);
        if(locatedIoVaMap == m_AllocatedIoVaList.end())
        {
            AllocatedIoVaListType::iterator itr = m_AllocatedIoVaList.upper_bound(ioVa);
            locatedIoVaMap = (itr==m_AllocatedIoVaList.begin()) ? (itr) : (--itr);
            if(ioVa < locatedIoVaMap->second.GetIoVa())
            {
                NN_USB_LOG_TRACE("IoVaManager::Free(ioVa=%llx) did not locate ioVa, out of bounds 1.\n",
                                 ioVa);
                return ResultIoVaError();
            }
            uint64_t offset = ioVa - locatedIoVaMap->second.GetIoVa();
            if(offset >= locatedIoVaMap->second.GetSize())
            {
                NN_USB_LOG_TRACE("IoVaManager::Free(ioVa=%llx) did not locate procVa, out of bounds 2.\n",
                                 ioVa);
                return ResultIoVaError();
            }
        }

        // make a local copy, since we are about to remove it from the map
        IoVaAllocatedSegment freedSegment = locatedIoVaMap->second;

        // release procVa from allocated map if provided
        if(freedSegment.GetProcVa()!=0ULL)
        {
            m_AllocatedProcVaList[freedSegment.GetProcVaIndex()].erase(
                std::make_pair(freedSegment.GetContext(),
                               freedSegment.GetProcVa())
            );
        }

        // release ioVa from allocated map
        m_AllocatedIoVaList.erase(locatedIoVaMap);

        // add iova to free map
        m_FreeSizeList.insert(std::make_pair(freedSegment.GetBaseIoVa(),
                                             IoVaFreeSegment(freedSegment.GetSize(),
                                                             freedSegment.GetBaseIoVa())));

        // Until we have logic to coalesce freed segments, re-init free list when last map is freed
        if(--m_MapCount == 0)
        {
            m_FreeSizeList.clear();
            m_FreeSizeList.insert(std::make_pair(m_HeapBase, IoVaFreeSegment(m_HeapSize, m_HeapBase)));
        }

        return ResultSuccess();
    }
    Result FreeByTag(int32_t tag, void* calloutContext,
                     void (*callback)(void* calloutContext,IoVaAllocationSummary& summary))
    {
        Result result = ResultSuccess();
        for (AllocatedIoVaListType::iterator itr = m_AllocatedIoVaList.begin(); itr != m_AllocatedIoVaList.end(); )
        {
            IoVaAllocatedSegment freedSegment = itr->second;
            itr++;
            if(freedSegment.GetTag() == tag)
            {
                IoVaAllocationSummary summary;
                summary.requestedSize = freedSegment.GetRequestedSize();
                summary.procVaIndex   = freedSegment.GetProcVaIndex();
                summary.procVa        = freedSegment.GetProcVa();
                summary.ioVa          = freedSegment.GetIoVa();
                summary.context       = freedSegment.GetContext();
                summary.tag           = freedSegment.GetTag();
                if(callback) (*callback)(calloutContext, summary);
                NN_USB_ABORT_UPON_ERROR(Free(freedSegment.GetBaseIoVa()));
            }
        }
        return result;
    }

    Result GetIoVa(int32_t procVaIndex, uint64_t context, uint64_t procVa,
                   nn::dd::DeviceVirtualAddress* pReturnedIoVa)
    {
        Result result;
        IoVaAllocationSummary summary;

        result = LookupByProcVa(procVaIndex, context, procVa, &summary);
        if (result.IsSuccess())
        {
            *pReturnedIoVa = summary.ioVa + (procVa - summary.procVa);
        }

        return result;
    }

    Result LookupByIoVa(nn::dd::DeviceVirtualAddress ioVa, IoVaAllocationSummary* pReturnedSummary)
    {
        Result result = ResultSuccess();
        uint64_t offset = 0ULL;
        if(m_AllocatedIoVaList.empty())
        {
            return ResultIoVaError();
        }
        auto locatedIoVaMap = m_AllocatedIoVaList.find(ioVa);
        if(locatedIoVaMap == m_AllocatedIoVaList.end())
        {
            AllocatedIoVaListType::iterator itr = m_AllocatedIoVaList.upper_bound(ioVa);
            locatedIoVaMap = (itr==m_AllocatedIoVaList.begin()) ? (itr) : (--itr);
            if(ioVa < locatedIoVaMap->second.GetIoVa())
            {
                //NN_USB_LOG_TRACE("IoVaManager::LookupByIoVa(ioVa=%llx) did not locate ioVa, out of bounds 1.\n", ioVa);
                return ResultIoVaError();
            }
            offset = ioVa - locatedIoVaMap->second.GetIoVa();
            if(offset >= locatedIoVaMap->second.GetSize())
            {
                //NN_USB_LOG_TRACE("IoVaManager::LookupByIoVa(ioVa=%llx) did not locate ioVa, out of bounds 2.\n", ioVa);
                return ResultIoVaError();
            }
        }
        if(pReturnedSummary!=NULL)
        {
            pReturnedSummary->requestedSize = locatedIoVaMap->second.GetRequestedSize();
            pReturnedSummary->procVaIndex   = locatedIoVaMap->second.GetProcVaIndex();
            pReturnedSummary->procVa        = locatedIoVaMap->second.GetProcVa();
            pReturnedSummary->ioVa          = locatedIoVaMap->second.GetIoVa();
            pReturnedSummary->context       = locatedIoVaMap->second.GetContext();
        }
        return result;
    }

    Result LookupByProcVa(int32_t procVaIndex, uint64_t context, uint64_t procVa, IoVaAllocationSummary* pReturnedSummary)
    {
        Result result = ResultSuccess();
        uint64_t offset = 0ULL;
        if(!(procVaIndex < MaxProcVaLists))
        {
            NN_USB_LOG_INFO("IoVaManager::LookupByProcVa(procVa=%llx, procVaIndex=%d) invalid procVaIndex\n",
                             procVa, procVaIndex);
            return ResultIoVaError();
        }
        if(m_AllocatedProcVaList[procVaIndex].empty())
        {
            NN_USB_LOG_TRACE("LookupByProcVa::GetIoVa(procVa=%llx, procVaIndex=%d) did not locate procVa, empty list.\n",
                             procVa, procVaIndex);
            return ResultIoVaError();
        }
        auto locatedProcVaMap = m_AllocatedProcVaList[procVaIndex].find(std::make_pair(context, procVa));
        if(locatedProcVaMap == m_AllocatedProcVaList[procVaIndex].end())
        {
            AllocatedProcVaListType::iterator itr = m_AllocatedProcVaList[procVaIndex].upper_bound(std::make_pair(context, procVa));

            if (itr == m_AllocatedProcVaList[procVaIndex].begin())
            {
                return ResultIoVaError();
            }

            locatedProcVaMap = --itr;

            if (locatedProcVaMap->second.GetContext() != context)
            {
                return ResultIoVaError();
            }

            offset = procVa - locatedProcVaMap->second.GetProcVa();
            if(offset >= locatedProcVaMap->second.GetRequestedSize())
            {
                return ResultIoVaError();
            }
        }
        if(pReturnedSummary!=NULL)
        {
            pReturnedSummary->requestedSize = locatedProcVaMap->second.GetRequestedSize();
            pReturnedSummary->procVaIndex   = locatedProcVaMap->second.GetProcVaIndex();
            pReturnedSummary->procVa        = locatedProcVaMap->second.GetProcVa();
            pReturnedSummary->ioVa          = locatedProcVaMap->second.GetIoVa();
            pReturnedSummary->context       = locatedProcVaMap->second.GetContext();
        }
        return result;
    }

    void DumpMaps(int32_t procVaIndex)
    {
        NN_USB_LOG_INFO("IoVa Heap %p + %x:\n", m_HeapBase, m_HeapSize);

        for (auto& pair : m_AllocatedProcVaList[procVaIndex])
        {
            auto& segment = pair.second;
            NN_USB_LOG_INFO("  %08x : %08x + %08x -> %08x | %08x + %08x\n",
                            segment.GetContext(), segment.GetProcVa(), segment.GetRequestedSize(),
                            segment.GetIoVa(), segment.GetBaseIoVa(), segment.GetSize());
        }
    }

private:
    typedef std::multimap<uint32_t, IoVaFreeSegment, std::less<uint32_t>,
        Allocator<std::pair<const uint32_t, IoVaFreeSegment>>> FreeSizeListType;
    typedef std::map<nn::dd::DeviceVirtualAddress, IoVaAllocatedSegment, std::less<nn::dd::DeviceVirtualAddress>,
        Allocator<std::pair<const nn::dd::DeviceVirtualAddress, IoVaAllocatedSegment>>> AllocatedIoVaListType;
    typedef std::map<std::pair<uint64_t, uint64_t>, IoVaAllocatedSegment, std::less<std::pair<uint64_t, uint64_t>>,
        Allocator<std::pair<const std::pair<uint64_t, uint64_t>, IoVaAllocatedSegment>>> AllocatedProcVaListType;
    FreeSizeListType             m_FreeSizeList;
    AllocatedIoVaListType        m_AllocatedIoVaList;
    AllocatedProcVaListType      m_AllocatedProcVaList[MaxProcVaLists];
    int32_t                      m_MapCount;
    nn::dd::DeviceVirtualAddress m_HeapBase;
    size_t                       m_HeapSize;
};


} // end of namespace detail
} // end of namespace usb
} // end of namespace nn
