﻿/*--------------------------------------------------------------------------------*
  Copyright (C)Nintendo All rights reserved.

  These coded instructions, statements, and computer programs contain proprietary
  information of Nintendo and/or its licensed developers and are protected by
  national and international copyright laws. They may not be disclosed to third
  parties or copied or duplicated in any form, in whole or in part, without the
  prior written consent of Nintendo.

  The content herein is highly confidential and should be handled accordingly.
 *--------------------------------------------------------------------------------*/

#include "sf_HipcEmulatedInProcessSession.h"

#include <nn/nn_Common.h>
#include <nn/nn_Abort.h>
#include <nn/nn_Result.h>
#include <nn/nn_SdkAssert.h>
#include <nn/result/result_HandlingUtility.h>

#include <nn/sf/sf_ISharedObject.h>
#include <nn/sf/sf_ObjectFactory.h>
#include <nn/sf/sf_NativeHandle.h>

#include <nn/sf/impl/sf_AllocationPolicies.h>
#include <nn/sf/impl/sf_ExpHeapAllocator.h>

#include <nn/os/os_MultipleWait.h>
#include <nn/os/os_SystemEvent.h>
#include <nn/os/os_Mutex.h>

#include <memory>
#include <mutex>
#include <nn/sf/hipc/detail/sf_HipcMessageBufferAccessor.h>
#include <nn/util/util_IntrusiveList.h>
#include <nn/util/util_Exchange.h>
#include <nn/util/util_BitUtil.h>

#include "sf_HipcEmulatedObjectBase.h"
#include "sf_HipcEmulatedSessionRequest.h"
#include "detail/sf_HipcHandleRegistrationInternal.h"

namespace nn { namespace sf { namespace hipc {

/**
 * @brief One request issued on an in-process emulated HIPC session.
 *
 * The request is created for a client message buffer and queued on its parent session.
 * OnReceive() copies the client's message into the server's message buffer. Completion
 * happens in one of two ways, and both signal m_ReplyedEvent:
 * - Reply() copies the server's reply back into the client's message buffer.
 * - CloseByServer() marks the request as not processed.
 * The client observes completion through WaitRequest() / AttachReplyEvent().
 */
class HipcEmulatedInProcessSessionRequest
    : public HipcEmulatedObjectBase
    , public HipcEmulatedSessionRequest
    , public nn::util::IntrusiveListBaseNode<HipcEmulatedInProcessSessionRequest>
{
private:

    // Maximum number of receive-list entries CopyMessage() can save from the destination
    // header; a header reporting more aborts instead of overflowing the local array.
    enum
    {
        ReceiveListCountMax = 13
    };

    SharedPointer<ISharedObject> m_Parent;   // reference to the session this request was created on

    nn::os::Mutex m_Mutex;                   // serializes OnReceive / Reply / CloseByServer
    void* m_ClientMessageBuffer;             // client message buffer: request source, reply destination
    size_t m_ClientMessageBufferSize;
    bool m_Received;                         // true once OnReceive() has copied the request
    bool m_Replyed;                          // true once Reply() or CloseByServer() has run
    bool m_Processed;                        // true if completed by Reply(), false if closed by the server

    nn::os::SystemEvent m_ReplyedEvent;      // signaled when m_Replyed becomes true

protected:

    /**
     * @param parent                  Session the request belongs to; a reference is added.
     * @param clientMessageBuffer     Client message buffer holding the request; the reply is written back here.
     * @param clientMessageBufferSize Size of clientMessageBuffer in bytes.
     */
    HipcEmulatedInProcessSessionRequest(ISharedObject* parent, void* clientMessageBuffer, size_t clientMessageBufferSize) NN_NOEXCEPT
        : m_Parent(parent, true)
        , m_Mutex(false)
        , m_ClientMessageBuffer(clientMessageBuffer)
        , m_ClientMessageBufferSize(clientMessageBufferSize)
        , m_Received(false)
        , m_Replyed(false)
        , m_Processed(false)  // previously left uninitialized
        , m_ReplyedEvent(nn::os::EventClearMode_ManualClear, true)
    {
    }

    ~HipcEmulatedInProcessSessionRequest() NN_NOEXCEPT
    {
        // If the server never received the request, CopyMessage() never passed its move
        // handles on, so they are disposed here.
        if (!m_Received)
        {
            CloseMoveHandleOnMessage();
        }
    }

private:

    // Saved copy of one receive-list entry from the destination header.
    struct ReceiveListEntry
    {
        cmif::PointerAndSize pointerAndSize;
    };

    /**
     * @brief Copies the pointer-descriptor data of srcMessage into one contiguous range.
     *
     * Each entry is packed into [begin, end) at 16-byte alignment. The matching pointer
     * descriptor of *pDstMessage is rewritten to refer to the copy.
     * Aborts if the range is too small or a descriptor's receive index differs from its position.
     */
    void CopyReceiveListSingle(detail::HipcMessageBufferAccessor *pDstMessage, uintptr_t begin, uintptr_t end, const detail::HipcMessageHeaderInfo& headerInfo, const detail::HipcMessageBufferAccessor& srcMessage) NN_NOEXCEPT
    {
        auto p = begin;
        for (int i = 0; i < headerInfo.pointerCount; ++i)
        {
            NN_ABORT_UNLESS(p <= end);
            p = util::align_up(p, 16);
            NN_ABORT_UNLESS(p <= end);
            auto&& e = srcMessage.GetPointer(i);
            auto receiveIndex = srcMessage.GetPointerReceiveIndex(i);
            NN_ABORT_UNLESS(receiveIndex == i);
            NN_ABORT_UNLESS(e.size <= end - p);
            auto srcAddress = hipc::detail::UnregisterAddress(e.pointer);
            std::memcpy(reinterpret_cast<char*>(p), reinterpret_cast<const char*>(srcAddress), e.size);
            pDstMessage->SetPointer(i, hipc::detail::RegisterAddress(p), e.size, receiveIndex);
            p += e.size;
        }
    }

    /**
     * @brief Copies each pointer descriptor i of srcMessage into receive-list entry i.
     *
     * The matching pointer descriptor of *pDstMessage is rewritten to refer to the copy.
     * Aborts if there are more pointers than entries, an entry is too small, or a
     * receive index differs from its position.
     */
    void CopyReceiveListMulti(detail::HipcMessageBufferAccessor *pDstMessage, const ReceiveListEntry dstReceiveListEntries[], int dstReceiveListCount, const detail::HipcMessageHeaderInfo& headerInfo, const detail::HipcMessageBufferAccessor& srcMessage) NN_NOEXCEPT
    {
        NN_ABORT_UNLESS(headerInfo.pointerCount <= dstReceiveListCount);
        for (int i = 0; i < headerInfo.pointerCount; ++i)
        {
            auto p = dstReceiveListEntries[i].pointerAndSize.pointer;
            auto size = dstReceiveListEntries[i].pointerAndSize.size;
            auto&& e = srcMessage.GetPointer(i);
            auto receiveIndex = srcMessage.GetPointerReceiveIndex(i);
            NN_ABORT_UNLESS(receiveIndex == i);
            NN_ABORT_UNLESS(e.size <= size);
            auto srcAddress = hipc::detail::UnregisterAddress(e.pointer);
            std::memcpy(reinterpret_cast<char*>(p), reinterpret_cast<const char*>(srcAddress), e.size);
            pDstMessage->SetPointer(i, hipc::detail::RegisterAddress(p), e.size, receiveIndex);
        }
    }

    /**
     * @brief Transfers an HIPC message from srcMessageBuffer into dstMessageBuffer.
     *
     * - The receive-list layout already in dstMessageBuffer is saved before SetupHeader()
     *   rewrites the destination header with the source header.
     * - Copy and move handles are passed through DuplicateHandle(). Invalid handles are
     *   forwarded unchanged.
     * - Pointer data is placed according to the destination's receive-buffer mode.
     * - Send / receive / exchange descriptors are forwarded only when isRequest is true.
     * - Raw data is copied last.
     *
     * Aborts on any malformed header or insufficient space.
     */
    void CopyMessage(void* dstMessageBuffer, size_t dstMessageBufferSize, void* srcMessageBuffer, size_t srcMessageBufferSize, bool isRequest) NN_NOEXCEPT
    {
        // Save the receive list stored in dstMessageBuffer before the header is overwritten.
        detail::HipcMessageReceiveBufferMode dstReceiveBufferMode;
        ReceiveListEntry dstReceiveListEntries[ReceiveListCountMax];
        int dstReceiveListCount;
        {
            detail::HipcMessageHeaderInfo dstHeaderInfo;
            detail::HipcMessageBufferAccessor message;
            NN_ABORT_UNLESS(message.ParseHeader(&dstHeaderInfo, dstMessageBuffer, dstMessageBufferSize));
            dstReceiveBufferMode = dstHeaderInfo.receiveBufferMode;
            dstReceiveListCount = dstHeaderInfo.receiveListCount;
            // Guard the fixed-size local array against an out-of-range count.
            NN_ABORT_UNLESS(dstReceiveListCount <= ReceiveListCountMax);
            for (int i = 0; i < dstReceiveListCount; ++i)
            {
                auto e = message.GetReceiveList(i);
                dstReceiveListEntries[i].pointerAndSize.pointer = hipc::detail::UnregisterAddress(e.pointer);
                dstReceiveListEntries[i].pointerAndSize.size = e.size;
            }
        }

        detail::HipcMessageHeaderInfo headerInfo;
        detail::HipcMessageBufferAccessor srcMessage;
        NN_ABORT_UNLESS(srcMessage.ParseHeader(&headerInfo, srcMessageBuffer, srcMessageBufferSize));
        NN_ABORT_UNLESS(srcMessage.GetMessageSize() <= dstMessageBufferSize);
        detail::HipcMessageBufferAccessor dstMessage;
        NN_ABORT_UNLESS(dstMessage.SetupHeader(dstMessageBuffer, dstMessageBufferSize, headerInfo));

        dstMessage.SetTag(srcMessage.GetTag());

        if (headerInfo.hasPid)
        {
            // TODO: use an appropriate value
            auto processId = static_cast<uint64_t>(0);
            dstMessage.SetProcessId(processId);
        }

        for (int i = 0; i < headerInfo.copyHandleCount; ++i)
        {
            auto handle = srcMessage.GetCopyHandle(i);
            if (handle == detail::InvalidInternalHandleValue)
            {
                dstMessage.SetCopyHandle(i, detail::InvalidInternalHandleValue);
                continue;
            }
            dstMessage.SetCopyHandle(i, hipc::detail::DuplicateHandle(handle, false));
        }

        for (int i = 0; i < headerInfo.moveHandleCount; ++i)
        {
            auto handle = srcMessage.GetMoveHandle(i);
            if (handle == detail::InvalidInternalHandleValue)
            {
                dstMessage.SetMoveHandle(i, detail::InvalidInternalHandleValue);
                continue;
            }
            dstMessage.SetMoveHandle(i, hipc::detail::DuplicateHandle(handle, true));
        }

        if (headerInfo.pointerCount > 0)
        {
            switch (dstReceiveBufferMode)
            {
                case detail::HipcMessageReceiveBufferMode_None:
                {
                    // The receiver provided no place for pointer data.
                    NN_ABORT("[SF-Internal]");
                }
                case detail::HipcMessageReceiveBufferMode_MessageBuffer:
                {
                    // Pointer data is received into the destination (receiver's) message
                    // buffer, right after its raw data. Writing it into the source buffer
                    // would hand the receiver pointers into memory the sender reuses.
                    auto begin = dstMessage.GetRawPointer() + headerInfo.rawDataByteSize;
                    auto end = reinterpret_cast<decltype(begin)>(dstMessageBuffer) + dstMessageBufferSize;
                    CopyReceiveListSingle(&dstMessage, begin, end, headerInfo, srcMessage);
                    break;
                }
                case detail::HipcMessageReceiveBufferMode_Single:
                {
                    // One receiver-supplied buffer (receive-list entry 0) holds all pointer data.
                    NN_ABORT_UNLESS(dstReceiveListCount >= 1);
                    auto p = dstReceiveListEntries[0].pointerAndSize.pointer;
                    auto size = dstReceiveListEntries[0].pointerAndSize.size;
                    CopyReceiveListSingle(&dstMessage, p, p + size, headerInfo, srcMessage);
                    break;
                }
                case detail::HipcMessageReceiveBufferMode_Multi:
                {
                    CopyReceiveListMulti(&dstMessage, dstReceiveListEntries, dstReceiveListCount, headerInfo, srcMessage);
                    break;
                }
                default: NN_UNEXPECTED_DEFAULT;
            }
        }

        if (isRequest)
        {
            for (int i = 0; i < headerInfo.sendCount; ++i)
            {
                auto p = srcMessage.GetSend(i);
                dstMessage.SetSend(i, p.pointer, p.size, p.mapTransferAttribute);
            }

            for (int i = 0; i < headerInfo.receiveCount; ++i)
            {
                auto p = srcMessage.GetReceive(i);
                dstMessage.SetReceive(i, p.pointer, p.size, p.mapTransferAttribute);
            }

            for (int i = 0; i < headerInfo.exchangeCount; ++i)
            {
                auto p = srcMessage.GetExchange(i);
                dstMessage.SetExchange(i, p.pointer, p.size, p.mapTransferAttribute);
            }
        }

        std::memcpy(
            reinterpret_cast<void*>(dstMessage.GetRawPointer()),
            reinterpret_cast<void*>(srcMessage.GetRawPointer()),
            headerInfo.rawDataByteSize);
    } // NOLINT(readability/fn_size)

    // Disposes every valid move handle carried by the client message buffer.
    void CloseMoveHandleOnMessage() NN_NOEXCEPT
    {
        detail::HipcMessageHeaderInfo headerInfo;
        detail::HipcMessageBufferAccessor message;
        NN_ABORT_UNLESS(message.ParseHeader(&headerInfo, m_ClientMessageBuffer, m_ClientMessageBufferSize));
        for (int i = 0; i < headerInfo.moveHandleCount; ++i)
        {
            auto handle = message.GetMoveHandle(i);
            if (handle == detail::InvalidInternalHandleValue)
            {
                // Same treatment as CopyMessage(): an invalid slot carries no handle to release.
                continue;
            }
            hipc::detail::DisposeHandle(handle);
        }
    }

public:

    // Server side: copies the client's request into the server's message buffer.
    void OnReceive(void* serverMessageBuffer, size_t serverMessageBufferSize) NN_NOEXCEPT
    {
        std::lock_guard<decltype(m_Mutex)> lk(m_Mutex);
        NN_SDK_ASSERT(!m_Received);
        CopyMessage(serverMessageBuffer, serverMessageBufferSize, m_ClientMessageBuffer, m_ClientMessageBufferSize, true);
        this->m_Received = true;
    }

    // Server side: copies the reply back to the client and wakes it with IsProcessed() == true.
    void Reply(void* serverMessageBuffer, size_t serverMessageBufferSize) NN_NOEXCEPT
    {
        std::lock_guard<decltype(m_Mutex)> lk(m_Mutex);
        NN_ABORT_UNLESS(!m_Replyed); // must not have been replied yet
        CopyMessage(m_ClientMessageBuffer, m_ClientMessageBufferSize, serverMessageBuffer, serverMessageBufferSize, false);
        this->m_Replyed = true;
        this->m_Processed = true;
        m_ReplyedEvent.Signal();
    }

    // Server side: completes the request without a reply; the client sees IsProcessed() == false.
    void CloseByServer() NN_NOEXCEPT
    {
        std::lock_guard<decltype(m_Mutex)> lk(m_Mutex);
        NN_ABORT_UNLESS(!m_Replyed); // must not have been replied yet
        this->m_Replyed = true;
        this->m_Processed = false;
        m_ReplyedEvent.Signal();
    }

    nn::os::SystemEvent* GetReplyedEvent() NN_NOEXCEPT
    {
        return &m_ReplyedEvent;
    }

    // Meaningful only after the replied event has been signaled.
    bool IsProcessed() NN_NOEXCEPT
    {
        return this->m_Processed;
    }

    virtual void AttachReplyEvent(nn::os::MultiWaitHolderType* pHolder) NN_NOEXCEPT NN_OVERRIDE
    {
        nn::os::InitializeMultiWaitHolder(pHolder, GetReplyedEvent()->GetBase());
    }

    // Blocks until the request completes; returns true if it was replied to.
    virtual bool WaitRequest() NN_NOEXCEPT NN_OVERRIDE
    {
        GetReplyedEvent()->Wait();
        return IsProcessed();
    }

    // Drops the client's reference to this request.
    virtual void CloseRequest() NN_NOEXCEPT NN_OVERRIDE
    {
        this->Release();
    }
};

namespace detail {

// Defined elsewhere. Called from ~HipcEmulatedInProcessSession() with the session's
// parent port when that port is non-null.
void OnChildSessionClosed(ISharedObject* port) NN_NOEXCEPT;

}  // namespace detail

/**
 * @brief Shared state of one emulated in-process HIPC session.
 *
 * The client end and the server end each hold a reference to this object. Clients enqueue
 * requests with CreateRequest(). The server takes requests one at a time with
 * ReceiveRequest() and completes them with Reply(). Either end may close the session.
 */
class HipcEmulatedInProcessSession
    : public HipcEmulatedObjectBase
{
private:

    /**
     * @brief A bool paired with a manual-clear event that is signaled while the value is true.
     *
     * Not synchronized internally. In this class, Set() and Get() are only called with
     * m_QueueMutex held, while waiting on GetEvent() happens without the lock.
     */
    class WaitableBool
    {
    private:

        bool m_Value;
        nn::os::SystemEvent m_Event;

    public:

        WaitableBool() NN_NOEXCEPT
            : m_Value(false)
            , m_Event(nn::os::EventClearMode_ManualClear, true)
        {
        }

        void Set(bool value) NN_NOEXCEPT
        {
            if (!(m_Value == value))
            {
                if (value)
                {
                    m_Event.Signal();
                }
                else
                {
                    m_Event.Clear();
                }
                this->m_Value = value;
            }
        }

        bool Get() const NN_NOEXCEPT
        {
            return m_Value;
        }

        nn::os::SystemEvent* GetEvent() NN_NOEXCEPT
        {
            return &m_Event;
        }

    };

    nn::os::Mutex m_QueueMutex;
    // Requests created by the client and not yet received; the queue owns one reference to each.
    nn::util::IntrusiveList<HipcEmulatedInProcessSessionRequest, nn::util::IntrusiveListBaseNodeTraits<HipcEmulatedInProcessSessionRequest>> m_Queue;
    // Request currently being processed by the server. Only whether it is nullptr is
    // protected by m_QueueMutex; the pointee is used by the server thread only.
    HipcEmulatedInProcessSessionRequest* m_CurrentRequest;
    bool m_ServerClosed;
    bool m_ClientClosed;
    WaitableBool m_CanReceive;            // true while ReceiveRequest() has something to return
    SharedPointer<ISharedObject> m_ParentPort;

    // Recomputes m_CanReceive. The caller must hold m_QueueMutex.
    // Receiving is possible when no request is in progress and there is either a queued
    // request or a client close to report.
    void UpdateCanReceive() NN_NOEXCEPT
    {
        m_CanReceive.Set((!m_Queue.empty() || m_ClientClosed) && m_CurrentRequest == nullptr);
    }

protected:

    explicit HipcEmulatedInProcessSession(ISharedObject* parentPort) NN_NOEXCEPT
        : m_QueueMutex(false)
        , m_CurrentRequest(nullptr)
        , m_ServerClosed(false)
        , m_ClientClosed(false)
        , m_ParentPort(parentPort, true)
    {
    }

    ~HipcEmulatedInProcessSession() NN_NOEXCEPT
    {
        if (m_ParentPort)
        {
            detail::OnChildSessionClosed(m_ParentPort.Get());
        }
    }

public:

    static HipcEmulatedInProcessSession* Create(ISharedObject* parentPort) NN_NOEXCEPT
    {
        return Factory<HipcEmulatedInProcessSession>::Create(parentPort);
    }

    // for Client
    // Queues a new request for the server. Returns nullptr if the server has already closed
    // the session. Otherwise, the returned pointer carries the caller's reference, and the
    // queue holds a second reference.
    HipcEmulatedSessionRequest* CreateRequest(void* clientMessageBuffer, size_t clientMessageBufferSize) NN_NOEXCEPT
    {
        auto request = SharedPointer<HipcEmulatedInProcessSessionRequest>(Factory<HipcEmulatedInProcessSessionRequest>::Create(this, clientMessageBuffer, clientMessageBufferSize), false);
        {
            std::lock_guard<decltype(m_QueueMutex)> lk(m_QueueMutex);
            if (m_ServerClosed)
            {
                return nullptr;
            }
            request->AddReference();
            m_Queue.push_back(*request.Get());
            UpdateCanReceive();
        }
        return request.Detach();
    }

    // for Client
    // Marks the client end as closed so the server's ReceiveRequest() reports it.
    void CloseByClient() NN_NOEXCEPT
    {
        std::lock_guard<decltype(m_QueueMutex)> lk(m_QueueMutex);
        if (m_ServerClosed)
        {
            return;
        }
        this->m_ClientClosed = true;
        UpdateCanReceive();
    }

    // for Server
    nn::os::SystemEvent* GetCanReceiveEvent() NN_NOEXCEPT
    {
        return m_CanReceive.GetEvent();
    }

    // for Server
    // Blocks until a request or a client close is available. On a client close, sets
    // *pClosed = true. Otherwise, copies the next request into messageBuffer and sets
    // *pClosed = false.
    Result ReceiveRequest(bool* pClosed, void* messageBuffer, size_t messageBufferSize) NN_NOEXCEPT
    {
        for (;;)
        {
            {
                std::unique_lock<decltype(m_QueueMutex)> lk(m_QueueMutex);
                NN_ABORT_UNLESS(!m_ServerClosed);
                if (!m_CanReceive.Get())
                {
                    // Not receivable yet: release the lock and wait.
                    lk.unlock();
                    m_CanReceive.GetEvent()->Wait();
                    continue;
                }
                if (m_ClientClosed)
                {
                    *pClosed = true;
                    NN_RESULT_SUCCESS;
                }
                NN_SDK_ASSERT(m_CurrentRequest == nullptr);
                // The queue's reference moves to m_CurrentRequest; Reply() releases it.
                m_CurrentRequest = &m_Queue.front();
                m_Queue.pop_front();
                UpdateCanReceive();
            }
            m_CurrentRequest->OnReceive(messageBuffer, messageBufferSize);
            *pClosed = false;
            NN_RESULT_SUCCESS;
        }
    }

    // for Server
    // Replies to the current request, releases the session's reference to it, and allows
    // the next receive.
    void Reply(void* messageBuffer, size_t messageBufferSize) NN_NOEXCEPT
    {
        NN_ABORT_UNLESS(m_CurrentRequest);
        m_CurrentRequest->Reply(messageBuffer, messageBufferSize);
        m_CurrentRequest->Release();
        {
            std::lock_guard<decltype(m_QueueMutex)> lk(m_QueueMutex);
            NN_ABORT_UNLESS(!m_ServerClosed);
            this->m_CurrentRequest = nullptr;
            UpdateCanReceive();
        }
    }

    // for Server
    /**
     * @brief Closes the session from the server side.
     *
     * The request in progress (if any) and every still-queued request are completed with
     * CloseByServer(), so their clients see IsProcessed() == false. The session's reference
     * to each request is released. Later CreateRequest() calls return nullptr.
     * Aborts if called twice.
     */
    void CloseByServer() NN_NOEXCEPT
    {
        HipcEmulatedInProcessSessionRequest* pCurrent = nullptr;
        decltype(m_Queue) pending;
        {
            // Detach everything under the lock, including m_CurrentRequest (whose nullness
            // client threads read in UpdateCanReceive()). Complete the requests outside the
            // lock to keep the critical section small.
            std::lock_guard<decltype(m_QueueMutex)> lk(m_QueueMutex);
            NN_ABORT_UNLESS(!m_ServerClosed);
            this->m_ServerClosed = true;
            pCurrent = this->m_CurrentRequest;
            this->m_CurrentRequest = nullptr;
            pending.splice(pending.end(), m_Queue);
        }
        if (pCurrent != nullptr)
        {
            pCurrent->CloseByServer();
            pCurrent->Release();
        }
        // Drain the local list; m_Queue is already empty after the splice.
        while (!pending.empty())
        {
            auto& request = pending.front();
            pending.pop_front();
            request.CloseByServer();
            // Drop the reference CreateRequest() added for the queue, as Reply() does for
            // received requests.
            // NOTE(review): if the client already released its reference, this destroys the
            // request, and its destructor parses the client message buffer to dispose
            // unreceived move handles. Confirm the buffer is still valid at that point.
            request.Release();
        }
    }

    // utility
    static std::pair<
        HipcEmulatedInProcessServerSession*,
        HipcEmulatedInProcessClientSession*
    > CreateHipcEmulatedInProcessSessionPair(ISharedObject* parentPort) NN_NOEXCEPT;

};

// Client end of an in-process session; it keeps its own reference to the shared session
// state (second argument true).
HipcEmulatedInProcessClientSession::HipcEmulatedInProcessClientSession(HipcEmulatedInProcessSession* session) NN_NOEXCEPT
    : m_Session(session, true)
{
    // Nothing else to set up.
}

// Destroying the client end tells the shared session that the client has closed.
HipcEmulatedInProcessClientSession::~HipcEmulatedInProcessClientSession() NN_NOEXCEPT
{
    this->m_Session->CloseByClient();
}

// Issues a request for the given client message buffer; forwards the session's result
// (nullptr once the server has closed).
HipcEmulatedSessionRequest* HipcEmulatedInProcessClientSession::CreateRequest(void* clientMessageBuffer, size_t clientMessageBufferSize) NN_NOEXCEPT
{
    HipcEmulatedSessionRequest* const pRequest = this->m_Session->CreateRequest(clientMessageBuffer, clientMessageBufferSize);
    return pRequest;
}

// Server end of an in-process session; it keeps its own reference to the shared session
// state (second argument true).
HipcEmulatedInProcessServerSession::HipcEmulatedInProcessServerSession(HipcEmulatedInProcessSession* session) NN_NOEXCEPT
    : m_Session(session, true)
{
    // Nothing else to set up.
}

// Destroying the server end closes the shared session from the server side.
HipcEmulatedInProcessServerSession::~HipcEmulatedInProcessServerSession() NN_NOEXCEPT
{
    this->m_Session->CloseByServer();
}

// Binds pHolder to the session's "can receive" event so a multi-wait wakes up when
// ReceiveRequest() has something to return.
void HipcEmulatedInProcessServerSession::AttachReceiveEvent(nn::os::MultiWaitHolderType* pHolder) NN_NOEXCEPT
{
    auto pCanReceiveEvent = this->m_Session->GetCanReceiveEvent();
    nn::os::InitializeMultiWaitHolder(pHolder, pCanReceiveEvent->GetBase());
}

// Waits for the next request (or a client close) on the shared session and copies it
// into messageBuffer.
Result HipcEmulatedInProcessServerSession::ReceiveRequest(bool* pClosed, void* messageBuffer, size_t messageBufferSize) NN_NOEXCEPT
{
    auto result = this->m_Session->ReceiveRequest(pClosed, messageBuffer, messageBufferSize);
    return result;
}

// Sends the reply in messageBuffer for the request currently being processed.
void HipcEmulatedInProcessServerSession::Reply(void* messageBuffer, size_t messageBufferSize) NN_NOEXCEPT
{
    this->m_Session->Reply(messageBuffer, messageBufferSize);
}

namespace {

    // Wraps p in a SharedPointer without adding a reference, so the result takes over the
    // reference p already carries.
    template <typename T>
    SharedPointer<T> MakeSharedAttached(T* p) NN_NOEXCEPT
    {
        const bool addReference = false;
        return SharedPointer<T>(p, addReference);
    }

}  // namespace

inline std::pair<
    HipcEmulatedInProcessServerSession*,
    HipcEmulatedInProcessClientSession*
> HipcEmulatedInProcessSession::CreateHipcEmulatedInProcessSessionPair(ISharedObject* parentPort) NN_NOEXCEPT
{
    // Take over the creation reference. Each end constructed below adds its own reference,
    // so this local one is dropped when the function returns.
    auto session = MakeSharedAttached(HipcEmulatedInProcessSession::Create(parentPort));
    auto pServer = Factory<HipcEmulatedInProcessServerSession>::Create(session.Get());
    auto pClient = Factory<HipcEmulatedInProcessClientSession>::Create(session.Get());
    return std::make_pair(pServer, pClient);
}

// Creates a connected server/client pair sharing one in-process emulated session.
// Returns the server end first and the client end second.
std::pair<
    HipcEmulatedInProcessServerSession*,
    HipcEmulatedInProcessClientSession*
> CreateHipcEmulatedInProcessSessionPair(ISharedObject* parentPort) NN_NOEXCEPT
{
    auto sessionPair = HipcEmulatedInProcessSession::CreateHipcEmulatedInProcessSessionPair(parentPort);
    return sessionPair;
}

}}}
