﻿/*--------------------------------------------------------------------------------*
  Copyright (C)Nintendo All rights reserved.

  These coded instructions, statements, and computer programs contain proprietary
  information of Nintendo and/or its licensed developers and are protected by
  national and international copyright laws. They may not be disclosed to third
  parties or copied or duplicated in any form, in whole or in part, without the
  prior written consent of Nintendo.

  The content herein is highly confidential and should be handled accordingly.
 *--------------------------------------------------------------------------------*/

#pragma once

#include <nn/nn_Common.h>
#include <nn/nn_BitTypes.h>
#include "../../../kern_KTaggedAddress.h"
#include "../../../kern_Utility.h"
#include "../kern_KPageTableDefinition.h"
#include <cstring>

namespace nn { namespace kern { namespace init { namespace ARMv8A {

/*
 * PteAttr: a raw 64-bit translation-table descriptor value (m_Attr) plus
 * accessors for the attribute fields shared by block and page descriptors.
 *
 * The 4-argument constructor assembles only attribute bits; the level-specific
 * subclasses (L1Entry / L2Entry / L3Entry) OR in the output address and the
 * descriptor-type bits [1:0].
 *
 * All constants above bit 31 are built from 64-bit literals (1ull) so the
 * shifts are well-defined even where `unsigned long` is only 32 bits wide
 * (LLP64 data model), matching GetBits()/ExtractBits() below.
 */
class PteAttr
{
protected:
    Bit64 m_Attr;
    // The `length`-bit field starting at bit `pos`, shifted down to bit 0.
    Bit64 GetBits(int32_t pos, int32_t length)     const { return (m_Attr >> pos) & ((1ull << length) - 1); }
    // The `length`-bit field starting at bit `pos`, left in place with every other bit cleared.
    Bit64 ExtractBits(int32_t pos, int32_t length) const { return m_Attr & (((1ull << length) - 1) << pos); }
public:
    // Execute-never bits (uxn = bit 54, pxn = bit 53) combined with the 2-bit
    // access-permission field at bits [7:6] (read back by GetAp()).
    enum Permission : Bit64
    {
        //                                           uxn            pxn            perm
        Permission_KernelReadWrite =            ((1ull << 54) | (1ull << 53) | (0 << 6)),
        Permission_UserReadWrite =              ((1ull << 54) | (1ull << 53) | (1 << 6)),
        Permission_KernelRead =                 ((1ull << 54) | (1ull << 53) | (2 << 6)),
        Permission_UserRead =                   ((1ull << 54) | (1ull << 53) | (3 << 6)),

        Permission_KernelReadWriteExecute =     ((1ull << 54) | (0ull << 53) | (0 << 6)),
        Permission_KernelReadExecute =          ((1ull << 54) | (0ull << 53) | (2 << 6)),
        Permission_UserReadExecute =            ((0ull << 54) | (1ull << 53) | (3 << 6)),
    };
    // Shareability field, bits [9:8] (read back by GetSh()).
    enum Shared : Bit64
    {
        Shared_NonShared =      0ull << 8,
        Shared_OuterShared =    2ull << 8,
        Shared_InnerShared =    3ull << 8,
    };
    // Attribute-index field, bits [4:2] (read back by GetAttrIndx()).
    // NOTE(review): the enumerator names presumably mirror the memory-attribute
    // register configuration done elsewhere (index 0..3) — confirm against it.
    enum Attribute : Bit64
    {
        Attribute_nGnRnE =              0ull << 2,
        Attribute_nGnRE =               1ull << 2,
        Attribute_NormalMemory =        2ull << 2,
        Attribute_UncacheNormalMemory = 3ull << 2,
    };
    // Builds the attribute bits of a block/page descriptor from its parts.
    // The access flag (bit 10) is always set; nonGlobal is placed at bit 11.
    PteAttr(bool nonGlobal, Shared shared, Permission perm, Attribute attr) :
        m_Attr(
                static_cast<Bit64>(perm)                |
                (1ull << 10)                            | // AF (access flag) is always set
                (static_cast<Bit64>(nonGlobal) << 11)   |
                static_cast<Bit64>(shared)              |
                static_cast<Bit64>(attr)                |
                0)
    {}
    // Copies `pteAttr` and ORs extra bits (address, type, contiguous hint) into it.
    PteAttr(const PteAttr& pteAttr, Bit64 attr) : m_Attr(pteAttr.m_Attr | attr) {}
    // Wraps an already-encoded descriptor value.
    PteAttr(Bit64 attr) : m_Attr(attr) {}

    bool  IsPage ()             const { return GetBits(0, 2) == 0x3; }   // type bits [1:0] == 0b11
    bool  IsContiguous()        const { return GetBits(52, 1) == 0x1; }  // bit 52
    uintptr_t  GetBlock()       const { return ExtractBits(21, 27); }    // address bits [47:21]
    uintptr_t  GetTable()       const { return ExtractBits(12, 36); }    // address bits [47:12]
    uint32_t GetUxn()           const { return GetBits(54, 1); }
    uint32_t GetPxn()           const { return GetBits(53, 1); }
    uint32_t GetContiguous()    const { return GetBits(52, 1); }
    uint32_t GetNonGlobal()     const { return GetBits(11, 1); }
    uint32_t GetAf()            const { return GetBits(10, 1); }
    uint32_t GetSh()            const { return GetBits(8, 2); }
    uint32_t GetAp()            const { return GetBits(6, 2); }
    uint32_t GetNs()            const { return GetBits(5, 1); }
    uint32_t GetAttrIndx()      const { return GetBits(2, 3); }
};

class L1Entry : public PteAttr
{
public:
    L1Entry(KPhysicalAddress phys, const PteAttr& pteAttr, bool isContiguous):
        PteAttr(pteAttr, ((static_cast<Bit64>(isContiguous) << 52) | GetAsInteger(phys) | 1))
    {}
    L1Entry(KPhysicalAddress phys, bool uxn, bool pxn):
        PteAttr((static_cast<Bit64>(0x1) << 61) | (static_cast<Bit64>(uxn) << 60) | (static_cast<Bit64>(pxn) << 59) | GetAsInteger(phys) | 3)
    {}
    bool  IsBlock()             const { return GetBits(0, 2) == 0x1; }
    bool  IsTable()             const { return GetBits(0, 2) == 0x3; }
    KPhysicalAddress GetBlock() const { return KPhysicalAddress(ExtractBits(30, 18)); }
    KPhysicalAddress GetTable() const { return KPhysicalAddress(ExtractBits(12, 36)); }
};

class L2Entry : public PteAttr
{
public:
    L2Entry(KPhysicalAddress phys, const PteAttr& pteAttr, bool isContiguous):
        PteAttr(pteAttr, ((static_cast<Bit64>(isContiguous) << 52) | GetAsInteger(phys) | 1))
    {}
    L2Entry(KPhysicalAddress phys, bool uxn, bool pxn):
        PteAttr((static_cast<Bit64>(0x1) << 61) | (static_cast<Bit64>(uxn) << 60) | (static_cast<Bit64>(pxn) << 59) | GetAsInteger(phys) | 3)
    {}
    bool  IsBlock()             const { return GetBits(0, 2) == 0x1; }
    bool  IsTable()             const { return GetBits(0, 2) == 0x3; }
    KPhysicalAddress GetBlock() const { return KPhysicalAddress(ExtractBits(21, 27)); }
    KPhysicalAddress GetTable() const { return KPhysicalAddress(ExtractBits(12, 36)); }
};

class L3Entry : public PteAttr
{
public:
    L3Entry(KPhysicalAddress phys, const PteAttr& pteAttr, bool isContiguous):
        PteAttr(pteAttr, ((static_cast<Bit64>(isContiguous) << 52) | GetAsInteger(phys) | 3))
    {}
    bool  IsPage()              const { return GetBits(0, 2) == 0x3; }
    KPhysicalAddress GetPage () const { return KPhysicalAddress(ExtractBits(12, 36)); }
};

class KPageTableBody
{
private:
    uintptr_t           m_Offset;   // ページテーブルアクセスに使用する KPhysicalAddress → KVirtualAddress のオフセット.
    KPhysicalAddress    m_L1Table;
    uint32_t            m_NumEntry;
    bool                m_IsKernel;

public:
    void Initialize(KPhysicalAddress table, uintptr_t offset, uint32_t numL1Entry, bool isKernel)
    {
        m_L1Table = table;
        m_Offset = offset;
        m_NumEntry = numL1Entry;
        m_IsKernel = isKernel;
    }
    class PageTableAllocator
    {
    public:
        virtual KPhysicalAddress Allocate() { return Null<KPhysicalAddress>(); }
        virtual void Free(KPhysicalAddress pa) { NN_UNUSED(pa); }
    };

    KPhysicalAddress GetPhysicalAddress(KProcessAddress vaddr) const
    {
        int l1Index = (GetAsInteger(vaddr) >> HW_MMU_L1_BLOCK_SHIFT) & (m_NumEntry - 1);
        L1Entry* pL1 = GetTypedPointer<L1Entry>(KVirtualAddress(GetAsInteger(m_L1Table) + m_Offset)) + l1Index;

        if (pL1->IsBlock())
        {
            return pL1->GetBlock() + (GetAsInteger(vaddr) & (L1BlockSize - 1));
        }

        if (pL1->IsTable())
        {
            KPhysicalAddress l2Table = pL1->GetTable();
            int l2Index = (GetAsInteger(vaddr) >> HW_MMU_L2_BLOCK_SHIFT) & (HW_MMU_NUM_PTE - 1);
            L2Entry* pL2 = GetTypedPointer<L2Entry>(KVirtualAddress(GetAsInteger(l2Table) + m_Offset)) + l2Index;

            if (pL2->IsBlock())
            {
                return pL2->GetBlock() + (GetAsInteger(vaddr) & (L2BlockSize - 1));
            }

            if (pL2->IsTable())
            {
                KPhysicalAddress l3Table = pL2->GetTable();
                int l3Index = (GetAsInteger(vaddr) >> HW_MMU_PAGE_SHIFT) & (HW_MMU_NUM_PTE - 1);
                L3Entry* pL3 = GetTypedPointer<L3Entry>(KVirtualAddress(GetAsInteger(l3Table) + m_Offset)) + l3Index;

                if (pL3->IsPage())
                {
                    return pL3->GetPage() + (GetAsInteger(vaddr) & (PageSize - 1));
                }
            }
        }
        for (;;) {}
        return Null<KPhysicalAddress>();
    }

    bool IsFree(KProcessAddress addr, size_t size) const
    {
        KProcessAddress end = addr + size;
        while (addr < end)
        {
            int l1Index = (GetAsInteger(addr) >> HW_MMU_L1_BLOCK_SHIFT) & (m_NumEntry - 1);
            L1Entry* pL1 = GetTypedPointer<L1Entry>(KVirtualAddress(GetAsInteger(m_L1Table) + m_Offset)) + l1Index;

            if (pL1->IsBlock())
            {
                return false;
            }

            if (pL1->IsTable())
            {
                KPhysicalAddress l2Table = pL1->GetTable();
                int l2Index = (GetAsInteger(addr) >> HW_MMU_L2_BLOCK_SHIFT) & (HW_MMU_NUM_PTE - 1);
                L2Entry* pL2 = GetTypedPointer<L2Entry>(KVirtualAddress(GetAsInteger(l2Table) + m_Offset)) + l2Index;

                if (pL2->IsBlock())
                {
                    return false;
                }

                if (pL2->IsTable())
                {
                    KPhysicalAddress l3Table = pL2->GetTable();
                    int l3Index = (GetAsInteger(addr) >> HW_MMU_PAGE_SHIFT) & (HW_MMU_NUM_PTE - 1);
                    L3Entry* pL3 = GetTypedPointer<L3Entry>(KVirtualAddress(GetAsInteger(l3Table) + m_Offset)) + l3Index;

                    if (pL3->IsPage())
                    {
                        return false;
                    }

                    {
                        addr = RoundDown(addr + PageSize, PageSize);
                    }
                }
                else
                {
                    addr = RoundDown(addr + L2BlockSize, L2BlockSize);
                }
            }
            else
            {
                addr = RoundDown(addr + L1BlockSize, L1BlockSize);
            }

            if (addr == Null<KProcessAddress>())
            {
                return true;
            }
        }
        return true;
    }

    void Map(KProcessAddress addr, size_t size, KPhysicalAddress physAddr, const PteAttr& attr, PageTableAllocator* pAllocator) const
    {
        size_t leftSize = size;

        while (leftSize > 0)
        {
            int l1Index = (GetAsInteger(addr) >> HW_MMU_L1_BLOCK_SHIFT) & (m_NumEntry - 1);
            L1Entry* pL1 = GetTypedPointer<L1Entry>(KVirtualAddress(GetAsInteger(m_L1Table) + m_Offset)) + l1Index;
            if ((GetAsInteger(addr) & (L1BlockSize - 1)) == 0 &&
                    (GetAsInteger(physAddr) & (L1BlockSize - 1)) == 0 &&
                    leftSize >= L1BlockSize)
            {
                *pL1 = L1Entry(physAddr, attr, false);
                addr += L1BlockSize;
                physAddr += L1BlockSize;
                leftSize -= L1BlockSize;
                continue;
            }

            if (!pL1->IsTable())
            {
                KPhysicalAddress pt = pAllocator->Allocate();
                std::memset(reinterpret_cast<void*>(GetAsInteger(pt) + m_Offset), 0, HW_MMU_PAGETABLE_SIZE);
                *pL1 = L1Entry(pt, m_IsKernel, true);
            }

            KPhysicalAddress l2Table = pL1->GetTable();
            int l2Index = (GetAsInteger(addr) >> HW_MMU_L2_BLOCK_SHIFT) & (HW_MMU_NUM_PTE - 1);
            L2Entry* pL2 = GetTypedPointer<L2Entry>(KVirtualAddress(GetAsInteger(l2Table) + m_Offset)) + l2Index;

            if ((GetAsInteger(addr) & (ContiguousL2BlockSize - 1)) == 0 &&
                    (GetAsInteger(physAddr) & (ContiguousL2BlockSize - 1)) == 0 &&
                    leftSize >= ContiguousL2BlockSize)
            {
                for (int i = 0; i < ContiguousL2BlockSize / L2BlockSize; i++)
                {
                    pL2[i] = L2Entry(physAddr, attr, true);
                    addr += L2BlockSize;
                    physAddr += L2BlockSize;
                    leftSize -= L2BlockSize;
                }
                continue;
            }

            if ((GetAsInteger(addr) & (L2BlockSize - 1)) == 0 &&
                    (GetAsInteger(physAddr) & (L2BlockSize - 1)) == 0 &&
                    leftSize >= L2BlockSize)
            {
                *pL2 = L2Entry(physAddr, attr, false);
                addr += L2BlockSize;
                physAddr += L2BlockSize;
                leftSize -= L2BlockSize;
                continue;
            }

            if (!pL2->IsTable())
            {
                KPhysicalAddress pt = pAllocator->Allocate();
                std::memset(reinterpret_cast<void*>(GetAsInteger(pt) + m_Offset), 0, HW_MMU_PAGETABLE_SIZE);
                *pL2 = L2Entry(pt, m_IsKernel, true);
            }

            KPhysicalAddress l3Table = pL2->GetTable();
            int l3Index = (GetAsInteger(addr) >> HW_MMU_PAGE_SHIFT) & (HW_MMU_NUM_PTE - 1);
            L3Entry* pL3 = GetTypedPointer<L3Entry>(KVirtualAddress(GetAsInteger(l3Table) + m_Offset)) + l3Index;

            if ((GetAsInteger(addr) & (ContiguousPageSize - 1)) == 0 &&
                    (GetAsInteger(physAddr) & (ContiguousPageSize - 1)) == 0 &&
                    leftSize >= ContiguousPageSize)
            {
                for (int i = 0; i < ContiguousPageSize / PageSize; i++)
                {
                    pL3[i] = L3Entry(physAddr, attr, true);
                    addr += PageSize;
                    physAddr += PageSize;
                    leftSize -= PageSize;
                }
                continue;
            }

            *pL3 = L3Entry(physAddr, attr, false);
            addr += PageSize;
            physAddr += PageSize;
            leftSize -= PageSize;
        }
    }

private:
    enum
    {
        L1BlockSize           = ( 1 << HW_MMU_L1_BLOCK_SHIFT),
        ContiguousL2BlockSize = (16 << HW_MMU_L2_BLOCK_SHIFT),
        L2BlockSize           = ( 1 << HW_MMU_L2_BLOCK_SHIFT),
        ContiguousPageSize    = (16 << HW_MMU_PAGE_SHIFT),
        PageSize              = ( 1 << HW_MMU_PAGE_SHIFT),
    };
};

}}}}

