﻿/*--------------------------------------------------------------------------------*
  Copyright (C)Nintendo All rights reserved.

  These coded instructions, statements, and computer programs contain proprietary
  information of Nintendo and/or its licensed developers and are protected by
  national and international copyright laws. They may not be disclosed to third
  parties or copied or duplicated in any form, in whole or in part, without the
  prior written consent of Nintendo.

  The content herein is highly confidential and should be handled accordingly.
 *--------------------------------------------------------------------------------*/

#include <nn/jitsrv/jitsrv_Service.h>
#include <nn/jitsrv/jitsrv_JitServices.sfdl.h>
#include <nn/jit/plugin/jit_DllInterface.h>

#include <nn/nn_Common.h>
#include <nn/nn_Abort.h>
#include <nn/nn_SdkAssert.h>
#include <nn/nn_SdkLog.h>
#include <utility>
#include <type_traits>
#include <memory>
#include <nn/result/result_HandlingUtility.h>
#include <nn/util/util_ScopeExit.h>
#include <nn/util/util_IntUtil.h>
#include <nn/sf/sf_ObjectFactory.h>
#include <nn/init/init_Malloc.h>
#include <nn/jit/jit_Result.h>
#include <nn/settings/fwdbg/settings_SettingsGetterApi.h>
#include <nn/ro.h>

#include <nn/svc/svc_Base.h>

#include "jitsrv_DetailUtility.h"
#include "jitsrv_DllPlugin.h"

namespace nn { namespace diag { namespace detail {

// Forward declaration of a diag-internal function that takes log metadata and a log body.
// NOTE(review): nothing in the visible part of this file calls it; presumably it is kept
// for the plugin diagnostics setup -- confirm before removing.
void CallAllLogObserver(const LogMetaData& logMetaData, const LogBody& logBody) NN_NOEXCEPT;

}}}

namespace nn { namespace jitsrv {

// Shared implementation of a JIT environment. It owns the rx/ro code memory holders, the
// work memory holder, the buffer the plugin DLL is loaded into, and the plugin itself.
// Derived classes hook into plugin loading through OnLoad().
class JitEnvironmentBaseImpl
{
private:

    // Allocator constructed for the current process; passed to every MemoryHolder::Initialize.
    detail::AslrAllocator m_AslrAllocator{svc::PSEUDO_HANDLE_CURRENT_PROCESS};

    // Code memory initialized with svc::MemoryPermission_ReadExecute.
    detail::MemoryHolder<detail::CodeMemory> m_RxMemoryHolder;
    // Code memory initialized with svc::MemoryPermission_Read.
    detail::MemoryHolder<detail::CodeMemory> m_RoMemoryHolder;
    // Owner addresses and sizes of the rx/ro code memory, filled in by Initialize().
    jit::JitEnvironmentInfo m_EnvironmentInfo{};

    // Owns a single allocation obtained from init::GetAllocator(), used as the buffer
    // the plugin DLL is loaded into.
    class PluginDllBuffer
    {
    private:
        void* m_P = nullptr;
        size_t m_Size = 0;

        // Returns the current allocation (if any) to the allocator and resets the state.
        void Release() NN_NOEXCEPT
        {
            if (m_P)
            {
                init::GetAllocator()->Free(m_P);
                m_P = nullptr;
            }
            m_Size = 0;
        }
    public:
        PluginDllBuffer() = default;

        // The allocation is freed in the destructor, so a copy would free it twice.
        PluginDllBuffer(const PluginDllBuffer&) = delete;
        PluginDllBuffer& operator=(const PluginDllBuffer&) = delete;

        // Allocates `size` bytes, passing os::MemoryPageSize as the second argument
        // (presumably the alignment -- confirm against the allocator interface).
        // Any previously held buffer is released first. On failure GetAddress()
        // returns nullptr and GetSize() returns 0.
        void Allocate(size_t size) NN_NOEXCEPT
        {
            Release();
            this->m_P = init::GetAllocator()->Allocate(size, os::MemoryPageSize);
            this->m_Size = m_P ? size : 0;
        }
        ~PluginDllBuffer() NN_NOEXCEPT
        {
            Release();
        }
        void* GetAddress() const NN_NOEXCEPT
        {
            return m_P;
        }
        size_t GetSize() const NN_NOEXCEPT
        {
            return m_Size;
        }
    };

    PluginDllBuffer m_PluginDllBuffer;
    DllPlugin m_Plugin;
    detail::MemoryHolder<detail::WorkMemory> m_WorkMemoryHolder;

    // Hook invoked by LoadPlugin() after the plugin has been loaded and configured.
    // A failure result makes LoadPlugin() fail and unload the plugin.
    virtual Result OnLoad() NN_NOEXCEPT
    {
        NN_RESULT_SUCCESS;
    }

protected:

    virtual ~JitEnvironmentBaseImpl() = default;

    // Scoped guard around a plugin call. The constructor calls BeginAccess() on the work
    // memory holder (and on the rx/ro holders when generatesCode is true) and fills in the
    // JitPluginEnvironment handed to the plugin; the destructor calls EndAccess() in the
    // reverse order.
    class Caller
    {
    private:

        JitEnvironmentBaseImpl* m_P;
        bool m_GeneratesCode;
        jit::plugin::JitPluginEnvironment m_Environment;

    public:

        explicit Caller(JitEnvironmentBaseImpl* p, bool generatesCode) NN_NOEXCEPT
            : m_P(p)
            , m_GeneratesCode(generatesCode)
            , m_Environment() // value-initialize so fields not assigned below are not indeterminate
        {
            m_Environment.environmentInfo = p->m_EnvironmentInfo;
            if (m_GeneratesCode)
            {
                m_P->m_RxMemoryHolder.BeginAccess();
                m_Environment.rxWritable = m_P->m_RxMemoryHolder.GetAddress();
                m_P->m_RoMemoryHolder.BeginAccess();
                m_Environment.roWritable = m_P->m_RoMemoryHolder.GetAddress();
            }
            else
            {
                // Calls that do not generate code get no pointers into the code memory.
                m_Environment.rxWritable = nullptr;
                m_Environment.roWritable = nullptr;
            }
            m_P->m_WorkMemoryHolder.BeginAccess();
            m_Environment.workMemory = m_P->m_WorkMemoryHolder.GetAddress();
            m_Environment.workMemorySize = m_P->m_WorkMemoryHolder.GetBody().GetSize();
        }

        // Each BeginAccess() must be matched by exactly one EndAccess(), so copies are forbidden.
        Caller(const Caller&) = delete;
        Caller& operator=(const Caller&) = delete;

        // Returns a copy of the environment description to pass to the plugin.
        const jit::plugin::JitPluginEnvironment GetEnvironment() const NN_NOEXCEPT
        {
            return m_Environment;
        }

        // Calls StoreDataCache() on the rx and ro holders for the given ranges.
        // Only valid for a Caller constructed with generatesCode == true.
        void StoreDataCache(const jit::CodeRange& rx, const jit::CodeRange& ro) NN_NOEXCEPT
        {
            NN_SDK_REQUIRES(m_GeneratesCode);
            m_P->m_RxMemoryHolder.StoreDataCache(rx);
            m_P->m_RoMemoryHolder.StoreDataCache(ro);
        }

        ~Caller() NN_NOEXCEPT
        {
            // Reverse order of the BeginAccess() calls made in the constructor.
            m_P->m_WorkMemoryHolder.EndAccess();
            if (m_GeneratesCode)
            {
                m_P->m_RoMemoryHolder.EndAccess();
                m_P->m_RxMemoryHolder.EndAccess();
            }
        }

    };

public:

    // Sets up the rx/ro code memory from the given handles and sizes, allocates the
    // plugin DLL buffer, and records the owner addresses and sizes in m_EnvironmentInfo.
    //
    // Returns the failure of either CodeMemory::Initialize call, or
    // jit::ResultOutOfMemoryForDllLoad when the plugin DLL buffer cannot be allocated.
    // No member other than the (empty) DLL buffer is modified on failure.
    Result Initialize(sf::NativeHandle processHandle, sf::NativeHandle rxHandle, uint64_t rxSize, sf::NativeHandle roHandle, uint64_t roSize, size_t pluginDllBufferSize) NN_NOEXCEPT
    {
        // NOTE(review): this allocator is a local built from the caller's process handle;
        // assumes CodeMemory does not keep the pointer beyond Initialize() -- confirm.
        detail::AslrAllocator aslrAllocator{svc::Handle(processHandle.GetOsHandle())};
        detail::CodeMemory rxMemory;
        NN_RESULT_DO(rxMemory.Initialize(&aslrAllocator, std::move(rxHandle), rxSize, svc::MemoryPermission_ReadExecute));
        detail::CodeMemory roMemory;
        NN_RESULT_DO(roMemory.Initialize(&aslrAllocator, std::move(roHandle), roSize, svc::MemoryPermission_Read));

        m_PluginDllBuffer.Allocate(pluginDllBufferSize);
        NN_RESULT_THROW_UNLESS(m_PluginDllBuffer.GetAddress(), jit::ResultOutOfMemoryForDllLoad());

        // Everything that can fail has succeeded; commit the locals into the members.
        m_RxMemoryHolder.GetBody().swap(rxMemory);
        m_RxMemoryHolder.Initialize(&m_AslrAllocator);
        m_RoMemoryHolder.GetBody().swap(roMemory);
        m_RoMemoryHolder.Initialize(&m_AslrAllocator);
        const auto& rx = m_RxMemoryHolder.GetBody();
        m_EnvironmentInfo.rxCodeAddress = rx.GetOwnerAddress();
        m_EnvironmentInfo.rxCodeSize = rx.GetSize();
        // The ro information must come from the ro holder (it was previously read from
        // the rx holder, which reported the rx address/size as the ro values).
        const auto& ro = m_RoMemoryHolder.GetBody();
        m_EnvironmentInfo.roCodeAddress = ro.GetOwnerAddress();
        m_EnvironmentInfo.roCodeSize = ro.GetSize();
        NN_RESULT_SUCCESS;
    }

    // Returns the rx/ro code addresses and sizes recorded by Initialize().
    const jit::JitEnvironmentInfo& GetEnvironmentInfo() const NN_NOEXCEPT
    {
        return m_EnvironmentInfo;
    }

    // Writes the recorded rx and ro code addresses to the output parameters.
    Result GetCodeAddress(sf::Out<uint64_t> pRxCodeAddress, sf::Out<uint64_t> pRoCodeAddress) NN_NOEXCEPT
    {
        *pRxCodeAddress = m_EnvironmentInfo.rxCodeAddress;
        *pRoCodeAddress = m_EnvironmentInfo.roCodeAddress;
        NN_RESULT_SUCCESS;
    }

    // Loads the plugin given by the nrr/nro buffers into the plugin DLL buffer, runs the
    // plugin's setup entry points, and starts the memory holders.
    //
    // Steps: initialize the work memory, initialize and load the plugin, check its version,
    // let it resolve basic symbols and set up diagnostics (both optional entry points),
    // call its Configure entry point, call OnLoad(), then commit the work memory, start
    // the rx/ro/work holders and call the plugin's OnPrepared entry point.
    //
    // Returns the first failing result from those steps, or
    // jit::ResultPluginVersionMismatch when the plugin reports a version greater than
    // jit::plugin::PluginVersion. If a step fails after the plugin was loaded, the plugin
    // is unloaded again.
    Result LoadPlugin(sf::InBuffer nrr, sf::InBuffer nro, sf::NativeHandle workMemoryHandle, uint64_t workMemorySize) NN_NOEXCEPT
    {
        auto success = false;

        detail::WorkMemory workMemory;
        NN_RESULT_DO(workMemory.Initialize(std::move(workMemoryHandle), workMemorySize));
        NN_RESULT_DO(m_Plugin.Initialize(nrr.GetPointerUnsafe(), nrr.GetSize(), nro.GetPointerUnsafe(), nro.GetSize(), m_PluginDllBuffer.GetAddress(), m_PluginDllBuffer.GetSize()));
        NN_RESULT_DO(m_Plugin.Load());
        // Roll back the load on every early return until `success` is set.
        NN_UTIL_SCOPE_EXIT
        {
            if (!success)
            {
                m_Plugin.Unload();
            }
        };

        // Version check
        {
            NN_JIT_PLUGIN_GET_SYMBOL(m_Plugin, getVersion, NN_JIT_PLUGIN_FUNCTION_NAME(GetVersion), false);
            auto version = getVersion();
            NN_RESULT_THROW_UNLESS(version <= jit::plugin::PluginVersion, jit::ResultPluginVersionMismatch());
        }

        // ResolveBasicSymbols (optional entry point)
        {
            NN_JIT_PLUGIN_TRY_GET_SYMBOL(m_Plugin, resolveBasicSymbols, NN_JIT_PLUGIN_FUNCTION_NAME(ResolveBasicSymbols), false);
            if (resolveBasicSymbols)
            {
                // The callback yields the symbol's address, or 0 when the lookup fails.
                resolveBasicSymbols([](const char* name)
                {
                    uintptr_t ret;
                    return ro::LookupSymbol(&ret, name).IsSuccess() ? ret : 0;
                });
            }
        }

        // diag (optional entry point)
        {
            NN_JIT_PLUGIN_TRY_GET_SYMBOL(m_Plugin, setup, NN_JIT_PLUGIN_FUNCTION_NAME(SetupDiagnostics), true);
            if (setup)
            {
                using namespace diag;
                using namespace jit::plugin;
                DiagnosticsParameters p = {};
                // Same contract as above: address on success, 0 on failure.
                p.lookupSymbol = [](const char* name)
                {
                    uintptr_t ret;
                    return ro::LookupSymbol(&ret, name).IsSuccess() ? ret : 0;
                };
                setup(DiagnosticsVersion, &p);
            }
        }

        auto codeMemorySecurity = jit::MemorySecurityMode_Default;
        auto workMemorySecurity = jit::MemorySecurityMode_Default;

        // Framework configuration
        {
            NN_JIT_PLUGIN_GET_SYMBOL(m_Plugin, configure, NN_JIT_PLUGIN_FUNCTION_NAME(Configure), false);
            if (nn::settings::fwdbg::IsDebugModeEnabled())
            {
                // Only in debug mode is the plugin given a parameter block, and the
                // security modes it leaves there are adopted.
                jit::plugin::ConfigureParameters p = {};
                p.codeMemorySecurity = jit::MemorySecurityMode_Default;
                p.workMemorySecurity = jit::MemorySecurityMode_Default;
                configure(&p);
                codeMemorySecurity = static_cast<jit::MemorySecurityMode>(p.codeMemorySecurity);
                workMemorySecurity = static_cast<jit::MemorySecurityMode>(p.workMemorySecurity);
            }
            else
            {
                // Otherwise the defaults chosen above stay in effect.
                configure(nullptr);
            }
        }

        NN_RESULT_DO(OnLoad());

        // No fallible step remains: disarm the rollback and commit the work memory.
        success = true;
        m_WorkMemoryHolder.GetBody().swap(workMemory);
        m_WorkMemoryHolder.Initialize(&m_AslrAllocator);

        m_RxMemoryHolder.Start(codeMemorySecurity);
        m_RoMemoryHolder.Start(codeMemorySecurity);
        m_WorkMemoryHolder.Start(workMemorySecurity);

        // OnPrepared
        {
            NN_JIT_PLUGIN_GET_SYMBOL(m_Plugin, onPrepared, NN_JIT_PLUGIN_FUNCTION_NAME(OnPrepared), false);
            Caller caller(this, false);
            onPrepared(caller.GetEnvironment());
        }

        NN_RESULT_SUCCESS;
    }

    // Gives derived classes access to the plugin, e.g. to look up additional symbols.
    DllPlugin& GetDllPlugin() NN_NOEXCEPT
    {
        return m_Plugin;
    }

};

// JIT environment that forwards GenerateCode and Control requests to the loaded plugin.
class JitEnvironmentImpl
    : public JitEnvironmentBaseImpl
{
private:

    // Plugin entry points; both stay nullptr until OnLoad() has resolved both of them.
    NN_JIT_PLUGIN_FUNCTION_POINTER_TYPE(Control) m_Control = nullptr;
    NN_JIT_PLUGIN_FUNCTION_POINTER_TYPE(GenerateCode) m_GenerateCode = nullptr;

    // Resolves the plugin's Control and GenerateCode entry points.
    // Both lookups run before either member is written, so a failed lookup never leaves
    // one pointer set while the caller unloads the plugin.
    // NOTE(review): assumes NN_JIT_PLUGIN_GET_SYMBOL returns a failure Result from this
    // function when the symbol is missing -- confirm against the macro definition.
    virtual Result OnLoad() NN_NOEXCEPT override
    {
        NN_JIT_PLUGIN_GET_SYMBOL(GetDllPlugin(), control, NN_JIT_PLUGIN_FUNCTION_NAME(Control), false);
        NN_JIT_PLUGIN_GET_SYMBOL(GetDllPlugin(), generateCode, NN_JIT_PLUGIN_FUNCTION_NAME(GenerateCode), false);
        this->m_Control = control;
        this->m_GenerateCode = generateCode;
        NN_RESULT_SUCCESS;
    }

public:

    // Asks the plugin to generate code into the given rx/ro buffer ranges.
    //
    // On success writes the plugin's return value to pOut and the ranges it reports as
    // generated to pGeneratedRx / pGeneratedRo.
    //
    // Returns jit::ResultInvalidCall when the plugin entry point is not resolved or an
    // input range is invalid for the code memory size, jit::ResultPluginInvalidRxOut /
    // jit::ResultPluginInvalidRoOut when a reported range is invalid or lies outside the
    // corresponding input range, and jit::ResultPluginError when the plugin returns non-zero.
    Result GenerateCode(sf::Out<int> pOut, sf::Out<jit::CodeRange> pGeneratedRx, sf::Out<jit::CodeRange> pGeneratedRo, const sf::InBuffer& source, uint64_t tag, const jit::CodeRange& rxBuffer, const jit::CodeRange& roBuffer, const Struct32& inData, uint32_t inDataSize, const sf::OutBuffer& outBuffer) NN_NOEXCEPT
    {
        NN_RESULT_THROW_UNLESS(m_GenerateCode, jit::ResultInvalidCall());
        NN_RESULT_THROW_UNLESS(rxBuffer.IsValid(GetEnvironmentInfo().rxCodeSize), jit::ResultInvalidCall());
        NN_RESULT_THROW_UNLESS(roBuffer.IsValid(GetEnvironmentInfo().roCodeSize), jit::ResultInvalidCall());

        // The output ranges start at the input offsets with everything else zeroed.
        jit::CodeRange rxOut{};
        rxOut.offset = rxBuffer.offset;
        jit::CodeRange roOut{};
        roOut.offset = roBuffer.offset;
        int ret = 0;

        {
            // Code memory access is open only for the lifetime of this guard.
            Caller caller(this, true);
            auto pluginResult = m_GenerateCode(&ret, &rxOut, &roOut, caller.GetEnvironment(), tag, source.GetPointerUnsafe(), source.GetSize(), rxBuffer, roBuffer, &inData, inDataSize, outBuffer.GetPointerUnsafe(), outBuffer.GetSize());

            // The ranges reported by the plugin are validated before its result code is examined.
            NN_RESULT_THROW_UNLESS(rxOut.IsValid(GetEnvironmentInfo().rxCodeSize), jit::ResultPluginInvalidRxOut());
            NN_RESULT_THROW_UNLESS(roOut.IsValid(GetEnvironmentInfo().roCodeSize), jit::ResultPluginInvalidRoOut());
            NN_RESULT_THROW_UNLESS(rxOut.IsInRange(rxBuffer), jit::ResultPluginInvalidRxOut());
            NN_RESULT_THROW_UNLESS(roOut.IsInRange(roBuffer), jit::ResultPluginInvalidRoOut());
            NN_RESULT_THROW_UNLESS(pluginResult == 0, jit::ResultPluginError());
            caller.StoreDataCache(rxOut, roOut);
        }

        *pOut = ret;
        *pGeneratedRx = rxOut;
        *pGeneratedRo = roOut;
        NN_RESULT_SUCCESS;
    }

    // Forwards a control request to the plugin without opening the code memory.
    //
    // On success writes the plugin's return value to pOut.
    // Returns jit::ResultInvalidCall when the plugin entry point is not resolved and
    // jit::ResultPluginUndefinedCall when the plugin returns non-zero.
    Result Control(sf::Out<int> pOut, uint64_t tag, const sf::InBuffer& inData, const sf::OutBuffer& outData) NN_NOEXCEPT
    {
        NN_RESULT_THROW_UNLESS(m_Control, jit::ResultInvalidCall());
        Caller caller(this, false);
        int ret = 0;
        auto pluginResult = m_Control(&ret, caller.GetEnvironment(), tag, inData.GetPointerUnsafe(), inData.GetSize(), outData.GetPointerUnsafe(), outData.GetSize());
        NN_RESULT_THROW_UNLESS(pluginResult == 0, jit::ResultPluginUndefinedCall());
        *pOut = ret;
        NN_RESULT_SUCCESS;
    }

};

// Implementation behind the JIT service interface; creates JIT environments on request.
class JitServiceImpl
{
public:

    // Creates a JitEnvironmentImpl, initializes it with the given process handle and
    // rx/ro code memory handles and sizes, and returns it through pOut.
    // A failure from Initialize() is returned as-is and pOut is left unwritten.
    Result CreateJitEnvironment(sf::Out<sf::SharedPointer<IJitEnvironment>> pOut, sf::NativeHandle&& processHandle, sf::NativeHandle&& rxHandle, std::uint64_t rxSize, sf::NativeHandle&& roHandle, uint64_t roSize) NN_NOEXCEPT
    {
        // Size requested for the buffer the plugin DLL is loaded into (1 MiB).
        const size_t pluginDllBufferSize = 1024 * 1024;

        auto environment = sf::CreateSharedObjectEmplaced<IJitEnvironment, JitEnvironmentImpl>();
        NN_RESULT_DO(environment.GetImpl().Initialize(
            std::move(processHandle),
            std::move(rxHandle), rxSize,
            std::move(roHandle), roSize,
            pluginDllBufferSize));

        *pOut = std::move(environment);
        NN_RESULT_SUCCESS;
    }

};

// File-scope service object that exposes JitServiceImpl through the IJitService interface.
sf::UnmanagedServiceObject<IJitService, JitServiceImpl> g_JitServiceImpl;

// Returns a shared pointer to the file-scope JIT service object.
sf::SharedPointer<IJitService> GetJitService() NN_NOEXCEPT
{
    sf::SharedPointer<IJitService> service = g_JitServiceImpl.GetShared();
    return service;
}

}}
