﻿/*--------------------------------------------------------------------------------*
  Copyright (C)Nintendo All rights reserved.

  These coded instructions, statements, and computer programs contain proprietary
  information of Nintendo and/or its licensed developers and are protected by
  national and international copyright laws. They may not be disclosed to third
  parties or copied or duplicated in any form, in whole or in part, without the
  prior written consent of Nintendo.

  The content herein is highly confidential and should be handled accordingly.
 *--------------------------------------------------------------------------------*/

#include <nn/nn_Common.h>
#include <nn/nn_BitTypes.h>
#include "../../kern_Platform.h"
#include "kern_KPageTableBody.h"
#include "../../kern_Utility.h"

namespace nn { namespace kern { namespace ARMv8A {
namespace
{
// Thin wrapper around a raw 64-bit value with bit-field helpers.
struct Pte64
{
    Bit64 value;

    explicit Pte64(Bit64 x): value(x) {}

    // Returns the 'length' bits starting at bit 'pos', shifted down to bit 0.
    // 'length' must be < 64 (a 64-bit shift by 64 is undefined).
    Bit64 GetBits(int32_t pos, int32_t length) const
    {
        const Bit64 mask = (1ull << length) - 1;
        return (value >> pos) & mask;
    }

    // Returns the 'length' bits starting at bit 'pos' left in place, with every other bit cleared.
    // 'length' must be < 64.
    Bit64 ExtractBits(int32_t pos, int32_t length) const
    {
        const Bit64 mask = ((1ull << length) - 1) << pos;
        return value & mask;
    }
};

// Decodes a virtual address for a 4-level, 4KB-granule walk.
// Widths are written as 12 page-offset bits + 9 index bits per level (+ 4 bits for a 16-entry contiguous run).
struct VirtualAddress : public Pte64
{
    explicit VirtualAddress(uintptr_t address) : Pte64(address) {}

    // Byte offset of the address inside the unit that contains it.
    uintptr_t GetPageOffset()               const { return ExtractBits(0, 12); }
    uintptr_t GetContiguousPageOffset()     const { return ExtractBits(0, 12 + 4); }
    uintptr_t GetL2BlockOffset()            const { return ExtractBits(0, 12 + 9); }
    uintptr_t GetContiguousL2BlockOffset()  const { return ExtractBits(0, 12 + 9 + 4); }
    uintptr_t GetL1BlockOffset()            const { return ExtractBits(0, 12 + 9 * 2); }
    uintptr_t GetContiguousL1BlockOffset()  const { return ExtractBits(0, 12 + 9 * 2 + 4); }

    // 9-bit table index consumed at each translation level.
    size_t GetL3Index()                     const { return GetBits(12, 9); }
    size_t GetL2Index()                     const { return GetBits(12 + 9, 9); }
    size_t GetL1Index()                     const { return GetBits(12 + 9 * 2, 9); }
    size_t GetL0Index()                     const { return GetBits(12 + 9 * 3, 9); }
};
// Level-1 descriptor: either a 1GB block or a pointer to an L2 table.
struct L1Entry : public Pte64
{
    explicit L1Entry(uintptr_t entry) : Pte64(entry) {}

    // Descriptor type, bits [1:0]: 0b01 = block, 0b11 = table.
    bool IsBlock()              const { return GetBits(0, 2) == 0x1; }
    bool IsTable()              const { return GetBits(0, 2) == 0x3; }
    bool IsContiguous()         const { return GetContiguous() != 0; }

    // Output address of a block (bits [47:30]) or of the next-level table (bits [47:12]).
    Bit64 GetBlock()            const { return ExtractBits(30, 18); }
    Bit64 GetTable()            const { return ExtractBits(12, 36); }

    // Attribute fields, listed in ascending bit order.
    uint32_t GetAttrIndx()      const { return GetBits(2, 3); }   // [4:2]
    uint32_t GetNs()            const { return GetBits(5, 1); }   // [5]
    uint32_t GetAp()            const { return GetBits(6, 2); }   // [7:6]
    uint32_t GetSh()            const { return GetBits(8, 2); }   // [9:8]
    uint32_t GetAf()            const { return GetBits(10, 1); }  // [10]
    uint32_t GetNonGlobal()     const { return GetBits(11, 1); }  // [11]
    uint32_t GetContiguous()    const { return GetBits(52, 1); }  // [52]
    uint32_t GetPxn()           const { return GetBits(53, 1); }  // [53]
    uint32_t GetUxn()           const { return GetBits(54, 1); }  // [54]
};
// Level-2 descriptor: either a 2MB block or a pointer to an L3 table.
struct L2Entry : public Pte64
{
    explicit L2Entry(uintptr_t entry) : Pte64(entry) {}

    // Descriptor type, bits [1:0]: 0b01 = block, 0b11 = table.
    bool  IsBlock()             const { return GetBits(0, 2) == 0x1; }
    bool  IsTable()             const { return GetBits(0, 2) == 0x3; }
    bool  IsContiguous()        const { return GetContiguous() != 0; }

    // Output address of a block (bits [47:21]) or of the next-level table (bits [47:12]).
    uintptr_t  GetBlock()       const { return ExtractBits(21, 27); }
    uintptr_t  GetTable()       const { return ExtractBits(12, 36); }

    // Attribute fields, listed in ascending bit order.
    uint32_t GetAttrIndx()      const { return GetBits(2, 3); }   // [4:2]
    uint32_t GetNs()            const { return GetBits(5, 1); }   // [5]
    uint32_t GetAp()            const { return GetBits(6, 2); }   // [7:6]
    uint32_t GetSh()            const { return GetBits(8, 2); }   // [9:8]
    uint32_t GetAf()            const { return GetBits(10, 1); }  // [10]
    uint32_t GetNonGlobal()     const { return GetBits(11, 1); }  // [11]
    uint32_t GetContiguous()    const { return GetBits(52, 1); }  // [52]
    uint32_t GetPxn()           const { return GetBits(53, 1); }  // [53]
    uint32_t GetUxn()           const { return GetBits(54, 1); }  // [54]
};
// Level-3 descriptor: maps a single 4KB page.
struct L3Entry : public Pte64
{
    explicit L3Entry(uintptr_t entry) : Pte64(entry) {}

    // At level 3 only descriptor type 0b11 denotes a valid page.
    bool  IsPage()              const { return GetBits(0, 2) == 0x3; }
    bool  IsContiguous()        const { return GetContiguous() != 0; }

    // Output address of the page, bits [47:12].
    uintptr_t  GetPage()        const { return ExtractBits(12, 36); }

    // Attribute fields, listed in ascending bit order.
    uint32_t GetAttrIndx()      const { return GetBits(2, 3); }   // [4:2]
    uint32_t GetNs()            const { return GetBits(5, 1); }   // [5]
    uint32_t GetAp()            const { return GetBits(6, 2); }   // [7:6]
    uint32_t GetSh()            const { return GetBits(8, 2); }   // [9:8]
    uint32_t GetAf()            const { return GetBits(10, 1); }  // [10]
    uint32_t GetNonGlobal()     const { return GetBits(11, 1); }  // [11]
    uint32_t GetContiguous()    const { return GetBits(52, 1); }  // [52]
    uint32_t GetPxn()           const { return GetBits(53, 1); }  // [53]
    uint32_t GetUxn()           const { return GetBits(54, 1); }  // [54]
};

// Descriptor written into empty slots; with bits [1:0] clear it is neither a block, a table nor a page.
const Bit64 UNMAPPED_ENTRY = 0x0;
}

/*
 * Resolves an L3 descriptor.
 *
 * pNumPages  receives the size of the containing unit in 4KB pages: 16 for a contiguous run, otherwise 1.
 * pOut       receives the physical address of srcAddress (page address + page offset), or 0 if unmapped.
 * Returns true when entryValue is a valid page descriptor.
 */
bool KPageTableBody::GetAddrFromL3Entry(
        size_t*             pNumPages,
        KPhysicalAddress*   pOut,
        Bit64            entryValue,
        KProcessAddress     srcAddress
        )
{
    const L3Entry entry(entryValue);

    if (!entry.IsPage())
    {
        // Not a page descriptor: report one unmapped page.
        *pOut      = 0;
        *pNumPages = (1 << (9 * 0));
        return false;
    }

    const VirtualAddress srcAddr(GetAsInteger(srcAddress));
    *pNumPages = entry.IsContiguous() ? (16 << (9 * 0)) : (1 << (9 * 0));
    *pOut      = entry.GetPage() + srcAddr.GetPageOffset();
    return true;
}

bool KPageTableBody::GetAddrFromL2Entry(
        size_t*             pNumPages,
        KPhysicalAddress*   pOut,
        const Bit64**    ppL3,
        Bit64            entryValue,
        KProcessAddress     srcAddress
        )
{
    const L2Entry entry(entryValue);
    const VirtualAddress srcAddr(GetAsInteger(srcAddress));

    if (entry.IsBlock())
    {
        if (entry.IsContiguous())
        {
            *pNumPages = (16 << (9 * 1));
        }
        else
        {
            *pNumPages = (1 << (9 * 1));
        }
        *pOut = entry.GetBlock() + srcAddr.GetL2BlockOffset();
        *ppL3 = NULL;
        return true;
    }
    else if (entry.IsTable())
    {
        const Bit64* pTable = GetTypedPointer<Bit64>(GetPageTableVirtualAddress(entry.GetTable()));
        const size_t l3Index = srcAddr.GetL3Index();
        *ppL3 = &pTable[l3Index];
        return GetAddrFromL3Entry(pNumPages, pOut, **ppL3, srcAddress);
    }
    else
    {
        *ppL3   = NULL;
        *pOut   = 0;
        *pNumPages = (1 << (9 * 1));
        return false;
    }
}

bool KPageTableBody::GetAddrFromL1Entry(
        size_t*             pNumPages,
        KPhysicalAddress*   pOut,
        const Bit64**    ppL2,
        const Bit64**    ppL3,
        Bit64            entryValue,
        KProcessAddress     srcAddress
        )
{
    const L1Entry entry(entryValue);
    const VirtualAddress srcAddr(GetAsInteger(srcAddress));

    if (entry.IsBlock())
    {
        if (entry.IsContiguous())
        {
            *pNumPages = (16 << (9 * 2));
        }
        else
        {
            *pNumPages = (1 << (9 * 2));
        }
        *pOut = entry.GetBlock() + srcAddr.GetL1BlockOffset();
        *ppL2 = NULL;
        *ppL3 = NULL;
        return true;
    }
    else if (entry.IsTable())
    {
        const Bit64* pTable = GetTypedPointer<Bit64>(GetPageTableVirtualAddress(entry.GetTable()));
        const size_t l2Index = srcAddr.GetL2Index();
        *ppL2 = &pTable[l2Index];
        return GetAddrFromL2Entry(pNumPages, pOut, ppL3, **ppL2, srcAddress);
    }
    else
    {
        *ppL2   = NULL;
        *ppL3   = NULL;
        *pOut   = 0;
        *pNumPages = (1 << (9 * 2));
        return false;
    }
}

// Adopts pTable as the L1 table of a user-process address space covering [begin, end),
// sizes it in whole 1GB L1 slots and marks every slot unmapped.
void KPageTableBody::InitializeForProcess(void* pTable, KProcessAddress begin, KProcessAddress end)
{
    // Span of one L1 descriptor: 12 page-offset bits + two levels of 9 index bits = 1GB.
    const unsigned int l1BlockSize = 1u << (12 + 9 * 2);

    m_pTable   = static_cast<Bit64*>(pTable);
    m_IsKernel = false;
    m_NumEntry = RoundUp(end - begin, l1BlockSize) / l1BlockSize;

    size_t slot = 0;
    while (slot < m_NumEntry)
    {
        m_pTable[slot++] = UNMAPPED_ENTRY;
    }
}

// Adopts pTable as the kernel's L1 table covering [begin, end), sized in whole 1GB L1 slots.
// The existing contents of pTable are left as they are (nothing is cleared here).
void KPageTableBody::InitializeForKernel(void* pTable, KProcessAddress begin, KProcessAddress end)
{
    // Span of one L1 descriptor: 12 page-offset bits + two levels of 9 index bits = 1GB.
    const unsigned int l1BlockSize = 1u << (12 + 9 * 2);

    m_pTable   = static_cast<Bit64*>(pTable);
    m_IsKernel = true;
    m_NumEntry = RoundUp(end - begin, l1BlockSize) / l1BlockSize;
}

// Returns the L1 table pointer supplied at initialization; this function frees nothing itself
// (presumably the caller reclaims the table — confirm at call sites).
Bit64* KPageTableBody::Finalize()
{
    Bit64* const pTable = m_pTable;
    return pTable;
}

/*
 * Starts a traversal at vaddr.
 *
 * *pOut receives the physical address of vaddr (0 when unmapped) and, in size, the byte size of the
 * whole unit (page, block or contiguous run) that contains vaddr. pContext is primed so that
 * TraverseNext continues with the unit that follows. Returns true when vaddr is mapped.
 * An address outside the range covered by this table yields an unmapped 1GB unit and a context
 * that TraverseNext treats as finished.
 */
bool KPageTableBody::TraverseBegin(TraverseData* pOut, TraverseContext* pContext, KProcessAddress vaddr) const
{
    const uint32_t numEntry = m_NumEntry;
    // Number of 8-byte descriptors in one translation table (512 with 4KB pages).
    const size_t entriesPerTable = NN_KERN_FINEST_PAGE_SIZE / sizeof(Bit64);
    const VirtualAddress addr(GetAsInteger(vaddr));
    const size_t l0Index = addr.GetL0Index();
    const size_t l1Index = addr.GetL1Index();

    // Work out whether vaddr lies in this table's range and, if so, which m_pTable slot holds its
    // L1 descriptor:
    //  - process: L0 index 0, L1 indices [0, numEntry)                     -> slot = l1Index
    //  - kernel : last L0 index, top numEntry L1 indices [512 - numEntry, 512) -> slot = l1Index - first
    // The slot is derived by subtraction rather than by masking with (numEntry - 1). Both agree when
    // numEntry is a power of two. Masking aliases slots otherwise, whereas subtraction matches the
    // linear m_pTable[0 .. numEntry) walk that TraverseNext performs.
    bool inRange;
    size_t tableIndex;
    if (m_IsKernel)
    {
        const size_t firstL1Index = entriesPerTable - numEntry;
        inRange    = (l0Index == entriesPerTable - 1) && (l1Index >= firstL1Index);
        tableIndex = l1Index - firstL1Index;
    }
    else
    {
        inRange    = (l0Index == 0) && (l1Index < numEntry);
        tableIndex = l1Index;
    }

    if (!inRange)
    {
        // Park the context at the end so TraverseNext keeps reporting "unmapped, finished".
        pContext->pL1 = &m_pTable[numEntry];
        pContext->pL2 = NULL;
        pContext->pL3 = NULL;
        pOut->address = 0;
        pOut->size = ((1 << (9 * 2)) * NN_KERN_FINEST_PAGE_SIZE);
        return false;
    }

    pContext->pL1 = &m_pTable[tableIndex];
    size_t numPages;
    bool isValid = GetAddrFromL1Entry(&numPages, &pOut->address, &pContext->pL2, &pContext->pL3, *pContext->pL1, vaddr);
    pOut->size = numPages * NN_KERN_FINEST_PAGE_SIZE;

    // Advance each level's cursor past the unit just reported. For contiguous runs, skip the rest of
    // the 16-entry run starting from vaddr's position inside it. pL1 always ends up on the next L1
    // slot to visit once the lower levels are exhausted.
    switch (numPages)
    {
    case (16 << (9 * 2)):
        {
            pContext->pL1 += 16 - addr.GetContiguousL1BlockOffset() / ((1 << (9 * 2)) * NN_KERN_FINEST_PAGE_SIZE);
        }
        break;
    case (1 << (9 * 2)):
        {
            pContext->pL1 += 1;
        }
        break;
    case (16 << (9 * 1)):
        {
            pContext->pL1 += 1;
            pContext->pL2 += 16 - addr.GetContiguousL2BlockOffset() / ((1 << (9 * 1)) * NN_KERN_FINEST_PAGE_SIZE);
        }
        break;
    case (1 << (9 * 1)):
        {
            pContext->pL1 += 1;
            pContext->pL2 += 1;
        }
        break;
    case (16 << (9 * 0)):
        {
            pContext->pL1 += 1;
            pContext->pL2 += 1;
            pContext->pL3 += 16 - addr.GetContiguousPageOffset() / ((1 << (9 * 0)) * NN_KERN_FINEST_PAGE_SIZE);
        }
        break;
    case (1 << (9 * 0)):
        {
            pContext->pL1 += 1;
            pContext->pL2 += 1;
            pContext->pL3 += 1;
        }
        break;
    default:
        {
            NN_KERN_ASSERT(0);
        }
        break;
    }
    return isValid;
}

/*
 * Advances a traversal started by TraverseBegin to the next mapping unit.
 *
 * pContext holds a "next descriptor to visit" pointer for each level. A level pointer is treated as
 * exhausted in either of two cases:
 *  - it is NULL, or
 *  - it has advanced exactly onto a 0x1000 boundary, i.e. one past the last of the 512 8-byte
 *    descriptors of its table.
 * When a level is exhausted, the walk resumes one level up.
 *
 * *pOut receives the physical address of the unit (0 when unmapped) and its size in bytes.
 * Returns true when the unit is mapped. Once the L1 cursor reaches m_pTable[m_NumEntry], every call
 * reports an unmapped 1GB unit and returns false.
 */
bool KPageTableBody::TraverseNext(TraverseData* pOut, TraverseContext* pContext) const
{
    size_t numPages;
    bool isValid = false;

    // L3 cursor exhausted (or never set): fall back to L2.
    if (reinterpret_cast<uintptr_t>(pContext->pL3) % 0x1000 == 0x000)
    {
        // L2 cursor exhausted too: take the next L1 descriptor.
        if (reinterpret_cast<uintptr_t>(pContext->pL2) % 0x1000 == 0x000)
        {
            size_t l1Index = pContext->pL1 - m_pTable;
            if (l1Index < m_NumEntry)
            {
                // srcAddress 0: the unit is visited from its start, so no intra-unit offset is added
                // and any child table is entered at its first descriptor.
                isValid = GetAddrFromL1Entry(&numPages, &pOut->address, &pContext->pL2, &pContext->pL3, *pContext->pL1, 0);
            }
            else
            {
                // Past the last L1 slot: park the context and report the end of the traversal.
                pContext->pL1 = &m_pTable[m_NumEntry];
                pContext->pL2 = NULL;
                pContext->pL3 = NULL;
                pOut->address = 0;
                pOut->size = ((1 << (9 * 2)) * NN_KERN_FINEST_PAGE_SIZE);
                return false;
            }
            // Step past the unit just reported; a contiguous run is skipped as a whole.
            switch (numPages)
            {
            case (16 << (9 * 2)):
                {
                    pContext->pL1 += 16;
                }
                break;
            case (1 << (9 * 2)):
                {
                    pContext->pL1 += 1;
                }
                break;
            case (16 << (9 * 1)):
                {
                    pContext->pL1 += 1;
                    pContext->pL2 += 16;
                }
                break;
            case (1 << (9 * 1)):
                {
                    pContext->pL1 += 1;
                    pContext->pL2 += 1;
                }
                break;
            case (16 << (9 * 0)):
                {
                    pContext->pL1 += 1;
                    pContext->pL2 += 1;
                    pContext->pL3 += 16;
                }
                break;
            case (1 << (9 * 0)):
                {
                    pContext->pL1 += 1;
                    pContext->pL2 += 1;
                    pContext->pL3 += 1;
                }
                break;
            default:
                {
                    NN_KERN_ASSERT(0);
                }
                break;
            }
        }
        else
        {
            // Still inside the current L2 table: take its next descriptor (pL1 already points at
            // the following L1 slot).
            isValid = GetAddrFromL2Entry(&numPages, &pOut->address, &pContext->pL3, *pContext->pL2, 0);
            switch (numPages)
            {
            case (16 << (9 * 1)):
                {
                    pContext->pL2 += 16;
                }
                break;
            case (1 << (9 * 1)):
                {
                    pContext->pL2 += 1;
                }
                break;
            case (16 << (9 * 0)):
                {
                    pContext->pL2 += 1;
                    pContext->pL3 += 16;
                }
                break;
            case (1 << (9 * 0)):
                {
                    pContext->pL2 += 1;
                    pContext->pL3 += 1;
                }
                break;
            default:
                {
                    NN_KERN_ASSERT(0);
                }
                break;
            }
        }
    }
    else
    {
        // Still inside the current L3 table: take its next descriptor.
        isValid = GetAddrFromL3Entry(&numPages, &pOut->address, *pContext->pL3, 0);

        switch (numPages)
        {
        case (16 << (9 * 0)):
            {
                pContext->pL3 += 16;
            }
            break;
        case (1 << (9 * 0)):
            {
                pContext->pL3 += 1;
            }
            break;
        default:
            {
                NN_KERN_ASSERT(0);
            }
            break;
        }
    }

    pOut->size = numPages * NN_KERN_FINEST_PAGE_SIZE;
    return isValid;
}


/*
 * Translates vaddr through this table.
 *
 * On success stores the physical address of vaddr in *pOut and returns true. Returns false (leaving
 * *pOut untouched) when vaddr is outside the range covered by this table or is not mapped.
 */
bool KPageTableBody::GetPhysicalAddress(KPhysicalAddress* pOut, KProcessAddress vaddr) const
{
    const uint32_t numEntry = m_NumEntry;
    // Number of 8-byte descriptors in one translation table (512 with 4KB pages).
    const size_t entriesPerTable = NN_KERN_FINEST_PAGE_SIZE / sizeof(Bit64);
    const VirtualAddress srcAddr(GetAsInteger(vaddr));
    const size_t l0Index = srcAddr.GetL0Index();
    const size_t l1Index = srcAddr.GetL1Index();

    // Range check, then map the L1 index onto its m_pTable slot. Subtraction is used instead of
    // masking with (numEntry - 1): both agree for power-of-two numEntry, and only subtraction stays
    // correct otherwise.
    size_t tableIndex;
    if (m_IsKernel)
    {
        // Kernel: last L0 index, top numEntry L1 indices.
        const size_t firstL1Index = entriesPerTable - numEntry;
        if (l0Index != entriesPerTable - 1 || l1Index < firstL1Index)
        {
            return false;
        }
        tableIndex = l1Index - firstL1Index;
    }
    else
    {
        // Process: L0 index 0, first numEntry L1 indices.
        if (l0Index != 0 || l1Index >= numEntry)
        {
            return false;
        }
        tableIndex = l1Index;
    }

    const L1Entry l1Entry(m_pTable[tableIndex]);
    if (l1Entry.IsBlock())
    {
        *pOut = l1Entry.GetBlock() + srcAddr.GetL1BlockOffset();
        return true;
    }
    else if (l1Entry.IsTable())
    {
        const Bit64* pL2Table = GetTypedPointer<Bit64>(GetPageTableVirtualAddress(l1Entry.GetTable()));
        const L2Entry l2Entry(pL2Table[srcAddr.GetL2Index()]);
        if (l2Entry.IsBlock())
        {
            *pOut = l2Entry.GetBlock() + srcAddr.GetL2BlockOffset();
            return true;
        }
        else if (l2Entry.IsTable())
        {
            const Bit64* pL3Table = GetTypedPointer<Bit64>(GetPageTableVirtualAddress(l2Entry.GetTable()));
            const L3Entry l3Entry(pL3Table[srcAddr.GetL3Index()]);
            if (l3Entry.IsPage())
            {
                *pOut = l3Entry.GetPage() + srcAddr.GetPageOffset();
                return true;
            }
        }
    }

    return false;
}

/*
 * Logs the translation of [begin, begin + size) one unit (1GB block, 2MB block or 4KB page) at a time.
 * Consecutive unmapped units are coalesced into a single "not mapped" line. Compiled to a no-op unless
 * NN_KERN_ENABLE_DUMP_PAGETABLE is defined.
 *
 * The walk stops early at the first address outside the range covered by this table. Any pending
 * "not mapped" run is still reported when that happens.
 */
void KPageTableBody::Dump(uintptr_t begin, size_t size) const
{
    NN_UNUSED(begin);
    NN_UNUSED(size);
#if defined NN_KERN_ENABLE_DUMP_PAGETABLE
    if (size == 0)
    {
        return;
    }
    // Inclusive last byte of the requested range; the loop must also visit the unit containing it.
    const uintptr_t last = begin + size - 1;
    uintptr_t addr = begin;
    Bit64 mair;
    HW_GET_MAIR_EL1(mair);  // NOTE(review): read but currently unused below.

    const uint32_t numEntry = m_NumEntry;
    const size_t entriesPerTable = NN_KERN_FINEST_PAGE_SIZE / sizeof(Bit64);

    bool notMapped = false;
    uintptr_t notMappedBegin = 0;

    while (addr <= last)
    {
        VirtualAddress vaddr(addr);
        size_t l0Index = vaddr.GetL0Index();
        size_t l1Index = vaddr.GetL1Index();
        size_t tableIndex;

        // Outside this table's range: stop, but fall through so a pending unmapped run is flushed.
        if (m_IsKernel)
        {
            const size_t firstL1Index = entriesPerTable - numEntry;
            if (l0Index != entriesPerTable - 1 || l1Index < firstL1Index)
            {
                break;
            }
            tableIndex = l1Index - firstL1Index;
        }
        else
        {
            if (l0Index != 0 || l1Index >= numEntry)
            {
                break;
            }
            tableIndex = l1Index;
        }

        const Bit64 l1EntryValue = m_pTable[tableIndex];
        const L1Entry l1Entry(l1EntryValue);
        if (l1Entry.IsBlock())
        {
            addr &= ~(0x40000000ul - 1);
            if (notMapped)
            {
                notMapped = false;
                NN_KERN_RELEASE_LOG("%p - %p: not mapped\n", notMappedBegin, addr - 1);
            }
            NN_KERN_RELEASE_LOG("%p: %016lx PA=%p SZ=1G UXN=%d PXN=%d Cont=%d nG=%d AF=%d SH=%x AP=%x NS=%d AttrIndx=%d\n",
                    addr, l1EntryValue, l1Entry.GetBlock(),
                    l1Entry.GetUxn(),
                    l1Entry.GetPxn(),
                    l1Entry.GetContiguous(),
                    l1Entry.GetNonGlobal(),
                    l1Entry.GetAf(),
                    l1Entry.GetSh(),
                    l1Entry.GetAp(),
                    l1Entry.GetNs(),
                    l1Entry.GetAttrIndx()
                    );
            addr += 0x40000000;
        }
        else if (l1Entry.IsTable())
        {
            const Bit64* pL2Table = GetTypedPointer<Bit64>(GetPageTableVirtualAddress(l1Entry.GetTable()));
            size_t l2Index = vaddr.GetL2Index();
            const Bit64 l2EntryValue = pL2Table[l2Index];
            const L2Entry l2Entry(l2EntryValue);

            if (l2Entry.IsBlock())
            {
                addr &= ~(0x00200000ul - 1);
                if (notMapped)
                {
                    notMapped = false;
                    NN_KERN_RELEASE_LOG("%p - %p: not mapped\n", notMappedBegin, addr - 1);
                }
                NN_KERN_RELEASE_LOG("%p: %016lx PA=%p SZ=2M UXN=%d PXN=%d Cont=%d nG=%d AF=%d SH=%d AP=%d NS=%d AttrIndx=%d\n",
                        addr, l2EntryValue, l2Entry.GetBlock(),
                        l2Entry.GetUxn(),
                        l2Entry.GetPxn(),
                        l2Entry.GetContiguous(),
                        l2Entry.GetNonGlobal(),
                        l2Entry.GetAf(),
                        l2Entry.GetSh(),
                        l2Entry.GetAp(),
                        l2Entry.GetNs(),
                        l2Entry.GetAttrIndx());
                addr += 0x00200000;
            }
            else if (l2Entry.IsTable())
            {
                const Bit64* pL3Table = GetTypedPointer<Bit64>(GetPageTableVirtualAddress(l2Entry.GetTable()));
                size_t l3Index = vaddr.GetL3Index();
                const Bit64 l3EntryValue = pL3Table[l3Index];
                const L3Entry l3Entry(l3EntryValue);

                if (l3Entry.IsPage())
                {
                    addr &= ~(0x1000ul - 1);
                    if (notMapped)
                    {
                        notMapped = false;
                        NN_KERN_RELEASE_LOG("%p - %p: not mapped\n", notMappedBegin, addr - 1);
                    }
                    NN_KERN_RELEASE_LOG("%p: %016lx PA=%p SZ=4K UXN=%d PXN=%d Cont=%d nG=%d AF=%d SH=%d AP=%d NS=%d AttrIndx=%d\n",
                            addr, l3EntryValue, l3Entry.GetPage(),
                            l3Entry.GetUxn(),
                            l3Entry.GetPxn(),
                            l3Entry.GetContiguous(),
                            l3Entry.GetNonGlobal(),
                            l3Entry.GetAf(),
                            l3Entry.GetSh(),
                            l3Entry.GetAp(),
                            l3Entry.GetNs(),
                            l3Entry.GetAttrIndx());
                    addr += 0x1000;
                }
                else
                {
                    addr &= ~(0x1000ul - 1);
                    if (!notMapped)
                    {
                        notMapped = true;
                        notMappedBegin = addr;
                    }
                    addr += 0x1000;
                }
            }
            else
            {
                addr &= ~(0x00200000ul - 1);
                if (!notMapped)
                {
                    notMapped = true;
                    notMappedBegin = addr;
                }
                addr += 0x00200000;
            }
        }
        else
        {
            addr &= ~(0x40000000ul - 1);
            if (!notMapped)
            {
                notMapped = true;
                notMappedBegin = addr;
            }
            addr += 0x40000000;
        }

        // addr was aligned to the unit size before being advanced by it, so overflow lands exactly
        // on 0. With an inclusive bound that would loop forever; stop instead.
        if (addr == 0)
        {
            break;
        }
    }
    if (notMapped)
    {
        NN_KERN_RELEASE_LOG("%p - %p: not mapped\n", notMappedBegin, addr - 1);
    }
#endif
}

/*
 * Counts the translation tables reachable from this L1 table.
 *
 * The count is the L1 table itself, plus one for every L2 table, plus one for every L3 table.
 * Only compiled in NN_KERN_FOR_DEVELOPMENT builds; otherwise returns 0.
 */
size_t KPageTableBody::CountPageTables() const
{
    size_t numTables = 0;
#ifdef NN_KERN_FOR_DEVELOPMENT
    // Number of 8-byte descriptors in one translation table (512 with 4KB pages).
    const size_t entriesPerTable = NN_KERN_FINEST_PAGE_SIZE / sizeof(Bit64);

    // The L1 table itself.
    numTables++;
    // l1Index < m_NumEntry already, so it indexes m_pTable directly. The former mask with
    // (m_NumEntry - 1) was a no-op for power-of-two counts and aliased slots otherwise.
    for (size_t l1Index = 0; l1Index < m_NumEntry; l1Index++)
    {
        const L1Entry l1Entry(m_pTable[l1Index]);
        if (!l1Entry.IsTable())
        {
            continue;
        }

        // One L2 table hangs off this L1 descriptor; resolve its address once, not per entry.
        numTables++;
        const Bit64* const pL2Table = GetTypedPointer<Bit64>(GetPageTableVirtualAddress(l1Entry.GetTable()));
        for (size_t l2Index = 0; l2Index < entriesPerTable; l2Index++)
        {
            const L2Entry l2Entry(pL2Table[l2Index]);
            if (l2Entry.IsTable())
            {
                // One L3 table hangs off this L2 descriptor.
                numTables++;
            }
        }
    }
#endif
    return numTables;
}

} } }

