﻿/*--------------------------------------------------------------------------------*
  Copyright (C)Nintendo All rights reserved.

  These coded instructions, statements, and computer programs contain proprietary
  information of Nintendo and/or its licensed developers and are protected by
  national and international copyright laws. They may not be disclosed to third
  parties or copied or duplicated in any form, in whole or in part, without the
  prior written consent of Nintendo.

  The content herein is highly confidential and should be handled accordingly.
 *--------------------------------------------------------------------------------*/

#pragma once

#include <nn/nn_Common.h>
#include <nn/svc/svc_MemoryMapSelect.h>
#include <nn/svc/svc_Handle.h>
#include <nn/svc/svc_Result.h>
#include <nn/svc/svc_ThreadLocalRegion.h>
#include <cstring>
#include <type_traits>
#include <nn/util/util_BitPack.h>
#include <nn/nn_BitTypes.h>

namespace nn {
namespace svc {
namespace ipc {
/**
 * @brief Returns the message buffer held in the region returned by
 *        GetThreadLocalRegion().
 *
 * @return Pointer to the first word of that region's messageBuffer member.
 */
inline Bit32* GetMessageBuffer()
{
    auto const pRegion = GetThreadLocalRegion();
    return pRegion->messageBuffer;
}

/**
 * @brief Non-owning accessor for an IPC message laid out as an array of 32-bit words.
 *
 * A MessageBuffer stores a pointer to the first word (m_p) and the buffer
 * size in bytes (m_Size). Every @a offset parameter below is a word index
 * from the start of the buffer. Every Set* function returns the word index
 * just past what it wrote, so writes can be chained.
 *
 * The nested types are small value types. Each packs one message component
 * into nn::util::BitPack32 words:
 * - MessageHeader
 * - SpecialHeader
 * - PointerData
 * - MapData
 * - ReceiveListEntry
 *
 * NOTE(review): no accessor checks offsets against m_Size; callers must keep
 * every access in bounds.
 */
class MessageBuffer
{
public:
    class MessageHeader;

    /**
     * @brief Wraps a caller-supplied buffer.
     * @param[in] pMessage First word of the buffer; not owned.
     * @param[in] size     Buffer size in bytes (reported by GetSize()).
     */
    MessageBuffer(Bit32* pMessage, size_t size): m_p(pMessage), m_Size(size) {}

    /**
     * @brief Wraps a buffer whose size is that of ThreadLocalRegion::messageBuffer.
     * @param[in] pMessage First word of the buffer; not owned.
     */
    explicit MessageBuffer(Bit32* pMessage): m_p(pMessage), m_Size(sizeof(nn::svc::ThreadLocalRegion().messageBuffer)) {}

    /**
     * @brief Copies @a num packed words into the buffer at word @a offset.
     *
     * The empty asm statements with a "memory" clobber are compiler barriers.
     * They keep the compiler from moving other memory accesses across the copy.
     *
     * @param[in] offset   Destination word index.
     * @param[in] pMessage Source words.
     * @param[in] num      Number of words to copy.
     * @return Word index just past the copied data.
     */
    int Set(int offset, const nn::util::BitPack32* pMessage, size_t num) const
    {
        asm volatile ("":::"memory");
        __builtin_memcpy(m_p + offset, pMessage, num * sizeof(nn::util::BitPack32));
        asm volatile ("":::"memory");
        // Explicit narrowing: offsets are ints throughout this class.
        return offset + static_cast<int>(num);
    }

    /**
     * @brief Copies @a num words, starting at word @a offset, out of the buffer.
     *
     * A compiler barrier before the copy keeps the compiler from reordering
     * memory accesses across that point.
     *
     * @param[in]  offset   Source word index.
     * @param[out] pMessage Destination for the copied words.
     * @param[in]  num      Number of words to copy.
     */
    void Get(int offset, nn::util::BitPack32* pMessage, size_t num) const
    {
        asm volatile ("":::"memory");
        __builtin_memcpy(pMessage, m_p + offset, num * sizeof(nn::util::BitPack32));
    }

public:
    /**
     * @brief The two-word header at word 0 of every message.
     *
     * Bit ranges below take Field<>::Next as the first bit after a field.
     *
     * Word 0:
     * - Tag: bits 0-15
     * - PointerNum: bits 16-19
     * - SendNum: bits 20-23
     * - ReceiveNum: bits 24-27
     * - ExchangeNum: bits 28-31
     *
     * Word 1:
     * - RawNum: bits 0-9
     * - ReceiveListNum: bits 10-13
     * - ReceiveListOffset: bits 20-30
     * - SpecialNum: bit 31
     */
    class MessageHeader
    {
    private:
        nn::util::BitPack32 m_Header[2];

        typedef nn::util::BitPack32::Field<0,               16,  Bit16> Tag;
        typedef nn::util::BitPack32::Field<Tag::Next,        4,  int> PointerNum;
        typedef nn::util::BitPack32::Field<PointerNum::Next, 4,  int> SendNum;
        typedef nn::util::BitPack32::Field<SendNum::Next,    4,  int> ReceiveNum;
        typedef nn::util::BitPack32::Field<ReceiveNum::Next, 4,  int> ExchangeNum;

        typedef nn::util::BitPack32::Field<0,                       10, int> RawNum;
        typedef nn::util::BitPack32::Field<RawNum::Next,             4, int> ReceiveListNum;
        typedef nn::util::BitPack32::Field<20,                      11, int> ReceiveListOffset;
        typedef nn::util::BitPack32::Field<ReceiveListOffset::Next,  1, int> SpecialNum;

        // Tag value of a default-constructed ("null") header.
        static const uint16_t TagNull     = 0;

    public:
        /**
         * Values of the ReceiveListNum field.
         *
         * MessageBuffer::GetMessageBufferSize() interprets them as follows:
         * - ReceiveList_None and ReceiveList_ToMessageBuffer add no entries.
         * - ReceiveList_ToOneBuffer adds one entry.
         * - Any larger value N adds (N - ReceiveList_BufferNumOffset) entries.
         *
         * ReceiveList_BufferNumMax is presumably the largest entry count
         * (15 - 2 fits the 4-bit field). It is not used in this class.
         */
        enum
        {
            ReceiveList_None            = 0,
            ReceiveList_ToMessageBuffer = 1,
            ReceiveList_ToOneBuffer     = 2,
            ReceiveList_BufferNumOffset = 2,
            ReceiveList_BufferNumMax    = 13,
        };

    public:
        /** @brief Returns true when the tag equals TagNull (0). */
        bool IsNull() const
        {
            return GetTag() == TagNull;
        }

        /**
         * @brief Builds a header from its individual fields.
         * @param[in] tag            Message tag (word 0, bits 0-15).
         * @param[in] specialNum     Nonzero when a SpecialHeader follows this header (1-bit field).
         * @param[in] pointerNum     Number of PointerData entries.
         * @param[in] sendNum        Number of send MapData entries.
         * @param[in] receiveNum     Number of receive MapData entries.
         * @param[in] exchangeNum    Number of exchange MapData entries.
         * @param[in] rawNum         Length of the raw data section in words.
         * @param[in] receiveListNum One of the ReceiveList_* values, or a larger entry-count encoding.
         */
        MessageHeader(uint16_t tag, int specialNum, int pointerNum, int sendNum, int receiveNum, int exchangeNum, int rawNum, int receiveListNum)
        {
            m_Header[0] = { 0 };
            m_Header[0].Set<Tag>(tag);
            m_Header[0].Set<PointerNum>(pointerNum);
            m_Header[0].Set<SendNum>(sendNum);
            m_Header[0].Set<ReceiveNum>(receiveNum);
            m_Header[0].Set<ExchangeNum>(exchangeNum);

            m_Header[1] = { 0 };
            m_Header[1].Set<RawNum>(rawNum);
            m_Header[1].Set<ReceiveListNum>(receiveListNum);
            m_Header[1].Set<SpecialNum>(specialNum);
        }
        /** @brief Builds a null header: both words zero, tag == TagNull. */
        MessageHeader()
        {
            m_Header[0] = { 0 };
            m_Header[0].Set<Tag>(TagNull);
            m_Header[1] = { 0 };
        }
        /** @brief Reads the header from the first two words of @a msg. */
        explicit MessageHeader(const MessageBuffer& msg)
        {
            msg.Get(0, m_Header, sizeof(m_Header) / sizeof(*m_Header));
        }
        /** @brief Reads the header from pMessage[0] and pMessage[1]. */
        explicit MessageHeader(const Bit32* pMessage)
        {
            m_Header[0] = { pMessage[0] };
            m_Header[1] = { pMessage[1] };
        }
        uint16_t GetTag() const
        {
            return m_Header[0].Get<Tag>();
        }
        int GetPointerNum() const
        {
            return m_Header[0].Get<PointerNum>();
        }
        int GetSendNum() const
        {
            return m_Header[0].Get<SendNum>();
        }
        int GetReceiveNum() const
        {
            return m_Header[0].Get<ReceiveNum>();
        }
        int GetExchangeNum() const
        {
            return m_Header[0].Get<ExchangeNum>();
        }
        int GetRawNum() const
        {
            return m_Header[1].Get<RawNum>();
        }
        int GetReceiveListNum() const
        {
            return m_Header[1].Get<ReceiveListNum>();
        }
        /** @brief Explicit receive-list word offset; 0 means "compute it" (see MessageBuffer::GetReceiveListOffset). */
        int GetReceiveListOffset() const
        {
            return m_Header[1].Get<ReceiveListOffset>();
        }
        int GetSpecialNum() const
        {
            return m_Header[1].Get<SpecialNum>();
        }
        void SetReceiveListNum(int receiveListNum)
        {
            m_Header[1].Set<ReceiveListNum>(receiveListNum);
        }
        /** @brief Size of the packed header in bytes (two words). */
        constexpr static size_t GetSize()
        {
            return sizeof(m_Header);
        }
        /** @brief Packed header words, suitable for MessageBuffer::Set. */
        const nn::util::BitPack32* GetData() const
        {
            return m_Header;
        }
    };

    /**
     * @brief Optional one-word header placed right after MessageHeader when the
     *        MessageHeader's SpecialNum field is nonzero.
     *
     * Bit layout:
     * - Pid: bit 0
     * - CopyHandleNum: bits 1-4
     * - MoveHandleNum: bits 5-8
     *
     * When it is absent (m_SpecialNum == 0), both GetHeaderSize() and
     * GetDataSize() return 0.
     */
    class SpecialHeader
    {
    private:
        nn::util::BitPack32 m_Header[1];
        int m_SpecialNum;

        typedef nn::util::BitPack32::Field<0,                    1, bool> Pid;
        typedef nn::util::BitPack32::Field<Pid::Next,            4,  int> CopyHandleNum;
        typedef nn::util::BitPack32::Field<CopyHandleNum::Next,  4,  int> MoveHandleNum;

    public:
        /**
         * @brief Builds a present special header.
         * @param[in] processIdFlag Whether a 64-bit process id is carried.
         * @param[in] copyNum       Number of copied handles.
         * @param[in] moveNum       Number of moved handles.
         */
        explicit SpecialHeader(bool processIdFlag, int copyNum, int moveNum): m_SpecialNum(1)
        {
            m_Header[0] = { 0 };
            m_Header[0].Set<Pid>(processIdFlag);
            m_Header[0].Set<CopyHandleNum>(copyNum);
            m_Header[0].Set<MoveHandleNum>(moveNum);
        }
        /**
         * @brief Reads the special header that follows @a header in @a msg.
         *
         * If @a header reports no special header, the packed word is zeroed
         * instead of being read from @a msg.
         */
        explicit SpecialHeader(const MessageBuffer& msg, const MessageHeader& header): m_SpecialNum(header.GetSpecialNum())
        {
            if (m_SpecialNum)
            {
                msg.Get(MessageHeader::GetSize() / sizeof(nn::util::BitPack32), m_Header, sizeof(m_Header) / sizeof(*m_Header));
            }
            else
            {
                m_Header[0] = { 0 };
            }
        }

        int GetMoveHandleNum() const
        {
            return m_Header[0].Get<MoveHandleNum>();
        }
        int GetCopyHandleNum() const
        {
            return m_Header[0].Get<CopyHandleNum>();
        }
        bool GetProcessIdFlag() const
        {
            return m_Header[0].Get<Pid>();
        }
        const nn::util::BitPack32* GetData() const
        {
            return m_Header;
        }
        /** @brief Size in bytes of the packed special header: 0 if absent, one word otherwise. */
        size_t GetHeaderSize() const
        {
            if (m_SpecialNum == 0)
            {
                return 0;
            }
            else
            {
                return sizeof(m_Header);
            }
        }
        /**
         * @brief Size in bytes of the data that follows the special header.
         *
         * The size is the sum of:
         * - sizeof(Bit64), when the process-id flag is set;
         * - sizeof(nnHandle) for each copied handle;
         * - sizeof(nnHandle) for each moved handle.
         *
         * Returns 0 when the special header is absent.
         */
        size_t GetDataSize() const
        {
            if (m_SpecialNum == 0)
            {
                return 0;
            }
            else
            {
                size_t size = 0;
                size += GetProcessIdFlag() ? sizeof(Bit64): 0;
                size += GetCopyHandleNum() * sizeof(nnHandle);
                size += GetMoveHandleNum() * sizeof(nnHandle);
                return size;
            }
        }
    };

    /**
     * @brief Two-word pointer descriptor.
     *
     * Word 0:
     * - PointerIndex: bits 0-3
     * - PointerAddressHi2: bits 6-8 (address bits 36 and up)
     * - PointerAddressHi: bits 12-15 (address bits 32-35)
     * - PointerSize: bits 16-31
     *
     * Word 1: address bits 0-31.
     */
    class PointerData
    {
    private:
        nn::util::BitPack32 m_Data[2];

        typedef nn::util::BitPack32::Field<0,                       4, int> PointerIndex;
        typedef nn::util::BitPack32::Field<6,                       3, uint32_t> PointerAddressHi2;
        typedef nn::util::BitPack32::Field<12,                      4, uint32_t> PointerAddressHi;
        typedef nn::util::BitPack32::Field<PointerAddressHi::Next, 16, uint32_t> PointerSize;

    public:
        /**
         * @brief Packs a pointer descriptor.
         *
         * NOTE(review): `address >> 36` and `size` are not masked here. This
         * relies on BitPack32::Set to fit them into their 3- and 16-bit
         * fields — TODO confirm.
         */
        PointerData(const void* pointer, size_t size, int index)
        {
            uint64_t address = reinterpret_cast<uintptr_t>(pointer);

            m_Data[0] = { 0 };
            m_Data[0].Set<PointerIndex>(index);
            m_Data[0].Set<PointerAddressHi>((address >> 32) & ((1u << PointerAddressHi::Width) - 1));
            m_Data[0].Set<PointerAddressHi2>(address >> 36);
            m_Data[0].Set<PointerSize>(size);
            m_Data[1] = { static_cast<uint32_t>(address) };
        }
        /** @brief Reads a pointer descriptor from word @a offset of @a msg. */
        PointerData(int offset, const MessageBuffer& msg)
        {
            msg.Get(offset, m_Data, sizeof(m_Data) / sizeof(*m_Data));
        }
        /** @brief Reassembles the address from the two high fields and word 1. */
        uintptr_t GetPointerAddress() const
        {
            uint64_t address = (static_cast<uint64_t>((m_Data[0].Get<PointerAddressHi2>() << 4) | m_Data[0].Get<PointerAddressHi>()) << 32) | m_Data[1].storage;
            return static_cast<uintptr_t>(address);
        }
        int GetPointerIndex() const
        {
            return m_Data[0].Get<PointerIndex>();
        }
        size_t GetPointerSize() const
        {
            return m_Data[0].Get<PointerSize>();
        }
        const nn::util::BitPack32* GetData() const
        {
            return m_Data;
        }
        /** @brief Size of the packed descriptor in bytes (two words). */
        constexpr static size_t GetSize()
        {
            return sizeof(m_Data);
        }
    };

    /**
     * @brief Three-word map descriptor, used for the send, receive and exchange entries.
     *
     * Word 0: size bits 0-31.
     * Word 1: address bits 0-31.
     *
     * Word 2:
     * - MapAttribute: bits 0-1
     * - MapAddressHi2: bits 2-4 (address bits 36 and up)
     * - MapSizeHi: bits 24-27 (size bits 32-35)
     * - MapAddressHi: bits 28-31 (address bits 32-35)
     */
    class MapData
    {
    private:
        nn::util::BitPack32 m_Data[3];

        typedef nn::util::BitPack32::Field<0,               2, uint32_t> MapAttribute;
        typedef nn::util::BitPack32::Field<2,               3, uint32_t> MapAddressHi2;
        typedef nn::util::BitPack32::Field<24,              4, uint32_t> MapSizeHi;
        typedef nn::util::BitPack32::Field<MapSizeHi::Next, 4, uint32_t> MapAddressHi;

    public:
        /** Values stored in the 2-bit MapAttribute field. */
        enum MapTransferAttribute
        {
            MapTransferAttribute_Ipc,
            MapTransferAttribute_NonSecureIpc,
            MapTransferAttribute_NonDeviceIpc = 3,
        };
        /** @brief Builds an all-zero descriptor. */
        MapData()
        {
            m_Data[0] = { 0 };
            m_Data[1] = { 0 };
            m_Data[2] = { 0 };
        }
        /** @brief Packs a map descriptor for @a pointer / @a size with transfer attribute @a attr. */
        MapData(const void* pointer, size_t size, MapTransferAttribute attr)
        {
            uint64_t address = reinterpret_cast<uintptr_t>(pointer);
            uint64_t size64 = size;

            m_Data[0] = { static_cast<uint32_t>(size64) };
            m_Data[1] = { static_cast<uint32_t>(address) };
            m_Data[2] = { 0 };
            m_Data[2].Set<MapSizeHi>(static_cast<uint32_t>(size64 >> 32));
            m_Data[2].Set<MapAddressHi>(static_cast<uint32_t>(address >> 32) & ((1u << MapAddressHi::Width) - 1));
            m_Data[2].Set<MapAddressHi2>(static_cast<uint32_t>(address >> 36));
            m_Data[2].Set<MapAttribute>(static_cast<uint32_t>(attr));
        }
        // For backward compatibility: defaults the attribute to MapTransferAttribute_Ipc.
        MapData(const void* pointer, size_t size)
            : MapData(pointer, size, MapTransferAttribute_Ipc)
        {
        }

        /** @brief Reads a map descriptor from word @a offset of @a msg. */
        MapData(int offset, const MessageBuffer& msg)
        {
            msg.Get(offset, m_Data, sizeof(m_Data) / sizeof(*m_Data));
        }
        uintptr_t GetDataAddress() const
        {
            uint64_t address = (static_cast<uint64_t>((m_Data[2].Get<MapAddressHi2>() << 4) | m_Data[2].Get<MapAddressHi>()) << 32) | m_Data[1].storage;
            return static_cast<uintptr_t>(address);
        }
        size_t GetDataSize() const
        {
            uint64_t size = (static_cast<uint64_t>(m_Data[2].Get<MapSizeHi>()) << 32) | m_Data[0].storage;
            return static_cast<size_t>(size);
        }
        MapTransferAttribute GetAttribute() const
        {
            return static_cast<MapTransferAttribute>(m_Data[2].Get<MapAttribute>());
        }
        const nn::util::BitPack32* GetData() const
        {
            return m_Data;
        }
        /** @brief Size of the packed descriptor in bytes (three words). */
        constexpr static size_t GetSize()
        {
            return sizeof(m_Data);
        }
    };

    /**
     * @brief Two-word receive-list entry.
     *
     * Word 0: address bits 0-31.
     *
     * Word 1:
     * - AddressHi: bits 0-6 (address bits 32-38)
     * - Size: bits 16-31
     */
    class ReceiveListEntry
    {
    private:
        nn::util::BitPack32 m_Data[2];

        typedef nn::util::BitPack32::Field<0,   7,  uint32_t> AddressHi;
        typedef nn::util::BitPack32::Field<16,  16, uint32_t> Size;

    public:
        ReceiveListEntry(void* pointer, size_t size)
        {
            uint64_t address = reinterpret_cast<uintptr_t>(pointer);

            m_Data[0] = { static_cast<uint32_t>(address) };
            m_Data[1] = { 0 };
            m_Data[1].Set<AddressHi>(static_cast<uint32_t>(address >> 32));
            m_Data[1].Set<Size>(size);
        }
        /** @brief Wraps two already-packed words. */
        ReceiveListEntry(Bit32 data0, Bit32 data1)
        {
            m_Data[0] = { data0 };
            m_Data[1] = { data1 };
        }
        uintptr_t GetDataAddress() const
        {
            uint64_t address = (static_cast<uint64_t>(m_Data[1].Get<AddressHi>()) << 32) | m_Data[0].storage;
            return static_cast<uintptr_t>(address);
        }
        uintptr_t GetDataSize() const
        {
            return m_Data[1].Get<Size>();
        }
        const nn::util::BitPack32* GetData() const
        {
            return m_Data;
        }
        /** @brief Size of the packed entry in bytes (two words). */
        constexpr static size_t GetSize()
        {
            return sizeof(m_Data);
        }
    };

public:
    /** @brief Writes a null MessageHeader (two zero words) at word 0. */
    void SetNull() const
    {
        Set(nn::svc::ipc::MessageBuffer::MessageHeader());
    }
    /** @brief Writes @a header at word 0 and returns the word index just past it. */
    int Set(const MessageHeader& header) const
    {
        __builtin_memcpy(m_p, header.GetData(), header.GetSize());
        return header.GetSize() / sizeof(Bit32);
    }
    /**
     * @brief Writes @a special right after the MessageHeader.
     *
     * When the special header is absent, nothing is written (its header size
     * is 0).
     *
     * @return Word index just past the special header.
     */
    int Set(const SpecialHeader& special) const
    {
        int offset = MessageHeader::GetSize() / sizeof(nn::util::BitPack32);
        __builtin_memcpy(m_p + offset, special.GetData(), special.GetHeaderSize());
        return offset + special.GetHeaderSize() / sizeof(Bit32);
    }
    int Set(int offset, const PointerData& pointerData) const
    {
        __builtin_memcpy(m_p + offset, pointerData.GetData(), pointerData.GetSize());
        return offset + pointerData.GetSize() / sizeof(Bit32);
    }
    int Set(int offset, const MapData& mapData) const
    {
        __builtin_memcpy(m_p + offset, mapData.GetData(), mapData.GetSize());
        return offset + mapData.GetSize() / sizeof(Bit32);
    }
    /** @brief Copies the bytes of @a handle to word @a offset. */
    int SetHandle(int offset, const nn::svc::Handle& handle) const
    {
        __builtin_memcpy(m_p + offset, &handle, sizeof(handle));
        return offset + sizeof(handle) / sizeof(Bit32);
    }
    /** @brief Writes a 64-bit process id over two words starting at @a offset. */
    int SetProcessId(int offset, const Bit64 value) const
    {
        __builtin_memcpy(m_p + offset, &value, sizeof(value));
        return offset + sizeof(value) / sizeof(Bit32);
    }
    int Set(int offset, const uint32_t value) const
    {
        __builtin_memcpy(m_p + offset, &value, sizeof(value));
        return offset + sizeof(value) / sizeof(Bit32);
    }
    int Set(int offset, const ReceiveListEntry& rcvListEntry) const
    {
        __builtin_memcpy(m_p + offset, rcvListEntry.GetData(), rcvListEntry.GetSize());
        return offset + rcvListEntry.GetSize() / sizeof(Bit32);
    }

    /** @brief Writes a null header followed by the inner value of @a result. */
    void SetAsyncResult(Result result) const
    {
        int offset = Set(nn::svc::ipc::MessageBuffer::MessageHeader());
        Result::InnerType resultValue = result.GetInnerValueForDebug();
        __builtin_memcpy(m_p + offset, &resultValue, sizeof(resultValue));
    }
    /**
     * @brief Reads back a result stored by SetAsyncResult().
     *
     * If the first two words equal a null header, the function returns a
     * Result constructed from the word that follows them. Otherwise it returns
     * ResultSuccess().
     */
    Result GetAsyncResult() const
    {
        MessageHeader msgHeader(m_p);
        MessageHeader nullHeader;
        if (__builtin_memcmp(msgHeader.GetData(), nullHeader.GetData(), nullHeader.GetSize()) == 0)
        {
            return result::detail::ConstructResult(m_p[nullHeader.GetSize() / sizeof(m_p[0])]);
        }
        else
        {
            return ResultSuccess();
        }
    }

    /** @brief Reads a 64-bit process id from two words starting at @a offset. */
    Bit64 GetProcessId(int offset) const
    {
        Bit64 pid;
        __builtin_memcpy(&pid, m_p + offset, sizeof(Bit64));
        return pid;
    }

    /** @brief Builds a Handle from the word at @a offset. */
    nn::svc::Handle GetHandle(int offset) const
    {
        nn::svc::Handle handle(m_p[offset]);
        return handle;
    }

    /**
     * @brief Returns a reference to a T located in the buffer at word @a offset.
     *
     * NOTE(review): the buffer is only guaranteed Bit32-aligned. For a T with
     * stricter alignment (e.g. a 64-bit type at an odd word offset), this
     * reference is misaligned. Prefer GetRawArray() to copy such values out.
     */
    template <typename T>
    const T& GetRaw(int offset) const
    {
        return *reinterpret_cast<const T*>(m_p + offset);
    }
    /**
     * @brief Copies @a value into the buffer at word @a offset.
     *
     * The copy uses __builtin_memcpy, like the other writers in this class,
     * rather than assigning through a reinterpret_cast'd T*. Assigning that
     * way violates strict aliasing. It also misaligns any T stricter than Bit32.
     * T is expected to be trivially copyable.
     *
     * @return Word index past the value, rounded up to a whole word.
     */
    template <typename T>
    int SetRaw(int offset, const T& value) const
    {
        __builtin_memcpy(m_p + offset, &value, sizeof(T));
        return offset + (sizeof(T) + sizeof(Bit32) - 1) / sizeof(Bit32);
    }
    /**
     * @brief Copies @a length bytes from @a pData into the buffer at word @a offset.
     * @return Word index past the data, rounded up to a whole word.
     */
    int SetRawArray(int offset, const void* pData, size_t length) const
    {
        __builtin_memcpy(m_p + offset, pData, length);
        return offset + (length + sizeof(Bit32) - 1) / sizeof(Bit32);
    }
    /** @brief Copies @a length bytes, starting at word @a offset, out of the buffer into @a pData. */
    void GetRawArray(int offset, void* pData, size_t length) const
    {
        __builtin_memcpy(pData, m_p + offset, length);
    }

    /** @brief Buffer size in bytes, as given at construction. */
    size_t GetSize() const
    {
        return m_Size;
    }

    /** @brief Word index of the special data: right after MessageHeader and, if present, SpecialHeader. */
    static int GetSpecialDataOffset(const MessageHeader &header, const SpecialHeader& special)
    {
        NN_UNUSED(header);

        size_t size = 0;
        size += MessageHeader::GetSize() / sizeof(Bit32);
        size += special.GetHeaderSize() / sizeof(Bit32);
        return size;
    }

    /** @brief Word index of the first PointerData: after the special data. */
    static int GetPointerDataOffset(const MessageHeader &header, const SpecialHeader& special)
    {
        size_t size = GetSpecialDataOffset(header, special);
        size += special.GetDataSize() / sizeof(Bit32);
        return size;
    }

    /** @brief Word index of the first MapData: after all PointerNum pointer descriptors. */
    static int GetMapDataOffset(const MessageHeader &header, const SpecialHeader& special)
    {
        size_t size = GetSpecialDataOffset(header, special);
        size += special.GetDataSize() / sizeof(Bit32);
        size += header.GetPointerNum() * PointerData::GetSize() / sizeof(Bit32);
        return size;
    }

    /** @brief Word index of the raw data: after the send, receive and exchange map descriptors. */
    static int GetRawDataOffset(const MessageHeader &header, const SpecialHeader& special)
    {
        size_t size = GetMapDataOffset(header, special);
        size += header.GetSendNum()     * MapData::GetSize() / sizeof(Bit32);
        size += header.GetReceiveNum()  * MapData::GetSize() / sizeof(Bit32);
        size += header.GetExchangeNum() * MapData::GetSize() / sizeof(Bit32);
        return size;
    }

    /**
     * @brief Word index of the receive list.
     *
     * The header's explicit offset is used when it is nonzero. Otherwise the
     * receive list is placed right after the RawNum words of raw data.
     */
    static int GetReceiveListOffset(const MessageHeader &header, const SpecialHeader& special)
    {
        if (header.GetReceiveListOffset())
        {
            return header.GetReceiveListOffset();
        }
        else
        {
            return GetRawDataOffset(header, special) + header.GetRawNum();
        }
    }

    /**
     * @brief Total message size in bytes, from word 0 through the end of the receive list.
     *
     * The receive-list size depends on ReceiveListNum:
     * - ReceiveList_None and ReceiveList_ToMessageBuffer add nothing.
     * - ReceiveList_ToOneBuffer adds one ReceiveListEntry.
     * - Any larger value N adds (N - ReceiveList_BufferNumOffset) entries.
     */
    static size_t GetMessageBufferSize(const MessageHeader &header, const SpecialHeader& special)
    {
        size_t msgSize = GetReceiveListOffset(header, special) * sizeof(Bit32);

        switch (header.GetReceiveListNum())
        {
        case MessageHeader::ReceiveList_None:
            break;
        case MessageHeader::ReceiveList_ToMessageBuffer:
            break;
        case MessageHeader::ReceiveList_ToOneBuffer:
            {
                msgSize += ReceiveListEntry::GetSize();
            }
            break;
        default:
            {
                msgSize += (header.GetReceiveListNum() - MessageHeader::ReceiveList_BufferNumOffset) * ReceiveListEntry::GetSize();
            }
            break;
        }

        return msgSize;
    }

private:
    Bit32* m_p;      // First word of the (non-owned) message buffer.
    size_t m_Size;   // Buffer size in bytes.
};

} // end of namespace ipc
} // end of namespace svc
} // end of namespace nn

