﻿/*--------------------------------------------------------------------------------*
  Copyright (C)Nintendo All rights reserved.

  These coded instructions, statements, and computer programs contain proprietary
  information of Nintendo and/or its licensed developers and are protected by
  national and international copyright laws. They may not be disclosed to third
  parties or copied or duplicated in any form, in whole or in part, without the
  prior written consent of Nintendo.

  The content herein is highly confidential and should be handled accordingly.
 *--------------------------------------------------------------------------------*/

#include <nn/ro/ro_Api.h>
#include <nn/ro/ro_Result.h>
#include <nn/ro/ro_Types.h>
#include <nn/ro/detail/ro_Elf.h>
#include <nn/ro/detail/ro_NroHeader.h>
#include <nn/ro/detail/ro_NrrHeader.h>
#include <nn/ro/detail/ro_RoModule.h>
#include <nn/ro/detail/ro_RoInterface.h>
#include <nn/ro/detail/ro_PortName.h>
#include <nn/rocrt/rocrt.h>
#include <nn/os.h>
#include <nn/nn_SdkAssert.h>
#include <nn/nn_Abort.h>
#include <nn/util/util_BitUtil.h>
#include <nn/nn_SdkLog.h>
#include <nn/util/util_IntrusiveList.h>
#include <nn/svc/svc_Base.h>
#include <nn/os/os_MemoryFence.h>
#include <nn/svc/svc_Handle.h>
#include <nn/sf/sf_HipcClient.h>
#include <nn/sf/sf_ExpHeapAllocator.h>

#include <nn/os/os_NativeHandle.h>

#include <mutex>
#include <shared_mutex>
#include <pthread.h>

#include "detail/ro_ReaderWriterLock.h"

namespace nn { namespace ro {

namespace {

// Minimal lockable adapter over a statically initialized os::MutexType so that it
// can be used with std::lock_guard. It deliberately has no constructors: it must
// stay an aggregate so it can be brace-initialized with NN_OS_MUTEX_INITIALIZER.
struct StaticMutex
{
    os::MutexType    m_Mutex;

    void lock() NN_NOEXCEPT
    {
        os::LockMutex(&this->m_Mutex);
    }

    void unlock() NN_NOEXCEPT
    {
        os::UnlockMutex(&this->m_Mutex);
    }
};

// Protects g_RegistrationList (taken in RegisterModuleInfo, UnregisterModuleInfo and Finalize).
StaticMutex g_RegistrationListLock = { NN_OS_MUTEX_INITIALIZER(false) };

// Set by Initialize once detail::g_RoLock has been initialized; nothing in this file clears it.
bool g_IsInitializedLock = false;

// Traits that let util::IntrusiveList map a RegistrationInfo to its embedded
// _listNode and, in reverse, map a node back to the RegistrationInfo that owns it.
class RegistrationInfoTraits
{
private:
    friend class util::IntrusiveList<RegistrationInfo, RegistrationInfoTraits>;

    static util::IntrusiveListNode& GetNode(RegistrationInfo& ref) NN_NOEXCEPT
    {
        util::IntrusiveListNode& node = ref._listNode;
        return node;
    }

    static const util::IntrusiveListNode& GetNode(const RegistrationInfo& ref) NN_NOEXCEPT
    {
        const util::IntrusiveListNode& node = ref._listNode;
        return node;
    }

    static RegistrationInfo& GetItem(util::IntrusiveListNode& node) NN_NOEXCEPT
    {
        // Reuse the const overload; the owner object itself is non-const, so
        // casting the constness back off is well-defined.
        const util::IntrusiveListNode& constNode = node;
        return const_cast<RegistrationInfo&>(GetItem(constNode));
    }

    static const RegistrationInfo& GetItem(const util::IntrusiveListNode& node) NN_NOEXCEPT
    {
        // Step back from the member to the start of the enclosing object.
        const char* pNodeBytes = reinterpret_cast<const char*>(&node);
        const char* pOwnerBytes = pNodeBytes - offsetof(RegistrationInfo, _listNode);
        return *reinterpret_cast<const RegistrationInfo*>(pOwnerBytes);
    }
};

// All NRR registrations currently active, in registration order.
using RegistrationList = util::IntrusiveList<RegistrationInfo, RegistrationInfoTraits>;
RegistrationList g_RegistrationList;

// Debug-only validation of LoadModule's bind flag: exactly one of BindFlag_Now and
// BindFlag_Lazy must be requested. Does nothing when SDK assertions are disabled.
void CheckBindFlag(int flag) NN_NOEXCEPT
{
    NN_UNUSED(flag);

    // At least one binding mode is requested...
    NN_SDK_ASSERT((flag & BindFlag_Now) == BindFlag_Now || (flag & BindFlag_Lazy) == BindFlag_Lazy);

    // ...but not both at once.
    NN_SDK_ASSERT(!((flag & BindFlag_Now) == BindFlag_Now && (flag & BindFlag_Lazy) == BindFlag_Lazy));
}


// Empty tag type used only as a template argument to ExpHeapStaticAllocator below
// (presumably to give this allocator its own static instance — confirm).
struct AllocatorTag
{
};

// Allocator policy for the HIPC proxy created in InitializeForRo
// (1024 is presumably the static heap size in bytes — confirm).
using MyAllocator = nn::sf::ExpHeapStaticAllocator<1024, AllocatorTag>;

// Proxy to the RO service; set by InitializeForRo, released by FinalizeForRo.
nn::sf::SharedPointer<detail::IRoInterface> g_RefRoInterface;

// True after InitializeForRo succeeds, false again after FinalizeForRo.
bool g_IsInitialized = false;

// Opens a session with the RO service and registers the current process with it.
// On success g_RefRoInterface holds the service proxy and g_IsInitialized is set.
// If registering the process handle fails, the proxy is dropped again.
// Returns the failing call's result, or success.
Result InitializeForRo() NN_NOEXCEPT
{
    NN_SDK_ASSERT( ! g_IsInitialized );

    MyAllocator::Initialize(nn::lmem::CreationOption_NoOption);

    auto result = nn::sf::CreateHipcProxyByName<nn::ro::detail::IRoInterface, MyAllocator::Policy>(
            &g_RefRoInterface, nn::ro::detail::PortNameForRo);
    if (result.IsFailure())
    {
        return result;
    }

    // NOTE(review): the `false` passed to NativeHandle presumably means the handle is not
    // owned by the wrapper (we must not close our own process handle) — confirm.
    result = g_RefRoInterface->RegisterProcessHandle(0, nn::sf::NativeHandle(os::GetCurrentProcessHandle(), false));
    if (result.IsFailure())
    {
        g_RefRoInterface.Reset();
        g_RefRoInterface = nullptr;
        return result;
    }

    g_IsInitialized = true;
    return result;
}

// Undoes InitializeForRo: releases the RO service proxy and clears the initialized flag.
void FinalizeForRo() NN_NOEXCEPT
{
    // Drop the service proxy first...
    g_RefRoInterface.Reset();
    g_RefRoInterface = nullptr;

    // ...then record that no session is open any more.
    g_IsInitialized = false;
}

} // namespace

// Initializes the RO library: on the first call it sets up detail::g_RoLock and
// installs the manual-lookup hook, then it opens the RO service session.
// Aborts if the service session cannot be established.
void Initialize() NN_NOEXCEPT
{
    // The lock is initialized only once; a later Initialize (e.g. after Finalize) reuses it.
    if (g_IsInitializedLock == false)
    {
        nn::os::InitializeReaderWriterLock(&detail::g_RoLock.m_ReaderWriterLock);
        // Ensure the lock initialization is visible before LookupGlobalManual can be called.
        nn::os::FenceMemoryAnyAny();
        g_IsInitializedLock = true;
        *nn::ro::detail::g_pLookupGlobalManualFunctionPointer = nn::ro::detail::LookupGlobalManual;
    }

    const auto result = InitializeForRo();
    NN_ABORT_UNLESS_RESULT_SUCCESS(result);
}

// Shuts the RO library down: unloads every manually loaded NRO (newest first),
// unregisters every NRR (newest first), then closes the RO service session.
// Any RO service failure during teardown aborts.
void Finalize() NN_NOEXCEPT
{
    // Acquire the locks up front so nobody else can take them while Finalize runs.
    std::lock_guard<detail::StaticReaderWriterLock> writerLock(detail::g_RoLock);
    std::lock_guard<StaticMutex> registerScopedLock(g_RegistrationListLock);

    // Release NRO modules: run finalizers, unlink, then unmap.
    // NOTE(review): unlike UnloadModule, no PreUnloadDll/PostUnloadDll svc::Break notifications
    // are issued here and other modules' references are not reverted — confirm this is intended.
    // Caller-held Module handles are not reachable from here, so they keep State_Loaded.
    while (!detail::g_pManualLoadList->empty())
    {
        detail::RoModule& module = detail::g_pManualLoadList->back();
        module.CallFini();
        // LoadModule placement-constructs the RoModule at an address inside the module image,
        // so read the base address before the memory is unmapped.
        const auto baseAddress = module.GetBase();
        detail::g_pManualLoadList->pop_back();
        const Result result = g_RefRoInterface->UnmapManualLoadModuleMemory(0, baseAddress);
        NN_ABORT_UNLESS_RESULT_SUCCESS(result);
    }

    // Release NRR registrations.
    while (!g_RegistrationList.empty())
    {
        RegistrationInfo& info = g_RegistrationList.back();
        g_RegistrationList.pop_back();
        const Result result = g_RefRoInterface->UnregisterModuleInfo(0, info._fileAddress);
        NN_ABORT_UNLESS_RESULT_SUCCESS(result);
        // Mirror UnregisterModuleInfo: the info is no longer linked, so it must not claim to be
        // registered (otherwise a later UnregisterModuleInfo would erase an unlinked node).
        info._state = RegistrationInfo::State_Unregistered;
    }

    FinalizeForRo();
}

// Reports in *pOutSize the buffer size LoadModule needs for the NRO image at pImage,
// i.e. the image's bss size. Returns ResultInvalidNroImage if the NRO signature is wrong.
Result  GetBufferSize(size_t* pOutSize, const void* pImage) NN_NOEXCEPT
{
    NN_SDK_ASSERT_NOT_NULL(pOutSize);
    NN_SDK_ASSERT_NOT_NULL(pImage);

    const detail::NroHeader& header = *static_cast<const detail::NroHeader*>(pImage);
    if (header.CheckSignature())
    {
        // Both the image and its bss are expected to be whole pages.
        NN_SDK_ASSERT_EQUAL((header.GetSize() & (os::MemoryPageSize - 1)), 0u);
        NN_SDK_ASSERT_EQUAL((header.GetBssSize() & (os::MemoryPageSize - 1)), 0u);

        *pOutSize = header.GetBssSize();
        return ResultSuccess();
    }

    return ResultInvalidNroImage();
}

// Registers the NRR image at pImage with the RO service, fills in *pOutInfo and links it
// into g_RegistrationList (so Finalize can unregister it later).
// Returns ResultInvalidNrrImage on a bad NRR signature, or the service's error when it is in
// the ResultRoError range; any other service failure aborts. *pOutInfo is written only on success.
Result  RegisterModuleInfo(RegistrationInfo* pOutInfo, const void* pImage) NN_NOEXCEPT
{
    NN_SDK_ASSERT_NOT_NULL(pOutInfo);
    NN_SDK_ASSERT_NOT_NULL(pImage);

    const detail::NrrHeader* pNrr = static_cast<const detail::NrrHeader*>(pImage);
    if (!pNrr->CheckSignature())
    {
        return ResultInvalidNrrImage();
    }

    const uintptr_t address = reinterpret_cast<uintptr_t>(pImage);
    const size_t size = pNrr->GetSize();
    NN_SDK_ASSERT_EQUAL((address & (os::MemoryPageSize - 1)), 0u);

    auto result = g_RefRoInterface->RegisterModuleInfo(0, address, size);
    if (result.IsFailure())
    {
        // Only RO-range errors are handed back to the caller; anything else is fatal.
        if (!ResultRoError::Includes(result))
        {
            NN_ABORT_UNLESS_RESULT_SUCCESS(result);
        }
        return result;
    }

    pOutInfo->_state = RegistrationInfo::State_Registered;
    pOutInfo->_fileAddress = address;
    new (&pOutInfo->_listNode) util::IntrusiveListNode;

    {
        std::lock_guard<StaticMutex> listLock(g_RegistrationListLock);
        g_RegistrationList.push_back(*pOutInfo);
    }

    return ResultSuccess();
}

// Unlinks pInfo from g_RegistrationList and unregisters its NRR image from the RO service.
// pInfo must come from a successful RegisterModuleInfo and not have been unregistered yet.
void    UnregisterModuleInfo(RegistrationInfo* pInfo) NN_NOEXCEPT
{
    NN_SDK_ASSERT_NOT_NULL(pInfo);
    NN_SDK_ASSERT_EQUAL(pInfo->_state, RegistrationInfo::State_Registered);

    {
        std::lock_guard<StaticMutex> lock(g_RegistrationListLock);
        g_RegistrationList.erase(g_RegistrationList.iterator_to(*pInfo));
    }

    // This function cannot report errors, so every failure aborts. Unlike RegisterModuleInfo,
    // RO-range errors are not special-cased: a separate ResultRoError::Includes check here
    // would be redundant with this unconditional abort.
    auto result = g_RefRoInterface->UnregisterModuleInfo(0, pInfo->_fileAddress);
    NN_ABORT_UNLESS_RESULT_SUCCESS(result);

    pInfo->_state = RegistrationInfo::State_Unregistered;
}

// Loads the NRO image at pImage as a manually loaded module.
//
// pImage, buffer and bufferSize are asserted to be page-aligned. buffer provides space for the
// module's bss, so bufferSize is asserted to be at least the value GetBufferSize reports.
// flag selects BindFlag_Now or BindFlag_Lazy (validated by CheckBindFlag).
//
// On success *pOutModule describes the loaded module and its state is Module::State_Loaded.
// Returns ResultInvalidNroImage for a bad NRO signature, or the RO service's error when it is
// in the ResultRoError range; any other service failure aborts.
Result  LoadModule(
        Module* pOutModule, const void* pImage,
        void* buffer, size_t bufferSize, int flag) NN_NOEXCEPT
{
    NN_SDK_ASSERT_NOT_NULL(pOutModule);
    NN_SDK_ASSERT_NOT_NULL(pImage);
    CheckBindFlag(flag);

    uintptr_t fileAddress = reinterpret_cast<uintptr_t>(pImage);
    NN_SDK_ASSERT_EQUAL((fileAddress & (os::MemoryPageSize - 1)), 0u);

    uintptr_t bufferAddress = reinterpret_cast<uintptr_t>(buffer);
    NN_SDK_ASSERT_EQUAL((bufferAddress & (os::MemoryPageSize - 1)), 0u);
    NN_SDK_ASSERT_EQUAL((bufferSize & (os::MemoryPageSize - 1)), 0u);

    // Validate the caller's file image before asking the service to map it.
    const detail::NroHeader* pHeader = reinterpret_cast<const detail::NroHeader*>(fileAddress);

    if (!pHeader->CheckSignature())
    {
        return ResultInvalidNroImage();
    }

    // Total mapped size = image + bss; the bss must fit into the caller's buffer.
    size_t imageSize        = pHeader->GetSize();
    size_t moduleSize       = imageSize + pHeader->GetBssSize();
    NN_SDK_ASSERT_LESS_EQUAL(moduleSize, imageSize + bufferSize);

    // Loader-side processing: the RO service maps the image and the bss buffer and reports the
    // mapped base in outAddress. The notification-only svc::Break calls bracket the mapping
    // (presumably so an attached debugger can track the load — confirm).
    // NOTE(review): PostLoadDll is issued even when the mapping failed, in which case
    // outAddress may still be 0 — confirm this is intended.
    uint64_t outAddress = 0;
    nn::svc::Break(static_cast<nn::svc::BreakReason>(nn::svc::BreakReason_PreLoadDll|nn::svc::BreakReason_NotificationOnlyFlag), fileAddress, sizeof(nn::ro::detail::NroHeader));
    auto result = g_RefRoInterface->MapManualLoadModuleMemory(
            &outAddress,
            0,
            fileAddress, imageSize,
            bufferAddress, bufferSize);
    nn::svc::Break(static_cast<nn::svc::BreakReason>(nn::svc::BreakReason_PostLoadDll|nn::svc::BreakReason_NotificationOnlyFlag), outAddress, sizeof(nn::ro::detail::NroHeader));

    // RO-range errors are returned to the caller; any other failure is fatal.
    if (!ResultRoError::Includes(result))
    {
        NN_ABORT_UNLESS_RESULT_SUCCESS(result);
    }
    if (result.IsFailure())
    {
        return result;
    }

    // Guards the narrowing of the 64-bit service address to uintptr_t below.
    NN_ABORT_UNLESS_LESS(outAddress, UINTPTR_MAX);
    uintptr_t baseAddress = static_cast<uintptr_t>(outAddress);

    // From here on, work on the mapped copy rather than the caller's file image.
    pHeader = reinterpret_cast<const detail::NroHeader*>(baseAddress);

    const rocrt::ModuleHeaderLocation* pLocation = pHeader->GetRocrtModuleHeaderLocation();
    const rocrt::ModuleHeader* pModuleHeader = rocrt::GetModuleHeader(pLocation);
    NN_ABORT_UNLESS(rocrt::CheckModuleHeaderSignature(pModuleHeader));

    // Zero the module's bss range as described by the rocrt module header.
    uintptr_t bssBegin = rocrt::GetBssBeginAddress(pModuleHeader, pLocation);
    uintptr_t bssEnd   = rocrt::GetBssEndAddress(pModuleHeader, pLocation);

    if (bssEnd - bssBegin > 0)
    {
        std::memset(reinterpret_cast<void*>(bssBegin), 0, bssEnd - bssBegin);
    }

    // The RoModule bookkeeping object is placement-constructed at the address given by the
    // rocrt module header, so it lives inside the mapped module rather than in caller memory.
    uintptr_t moduleOffset = rocrt::GetModuleOffset(pModuleHeader, pLocation);
    detail::RoModule* pModule = new(reinterpret_cast<void*>(moduleOffset)) detail::RoModule;
    pOutModule->_module = pModule;

    const detail::Elf::Dyn* pDyn =
        reinterpret_cast<const detail::Elf::Dyn*>(rocrt::GetDynamicOffset(pModuleHeader, pLocation));

    pModule->Initialize(baseAddress, moduleSize, pDyn);

    {
        // Publish the module and resolve symbols while holding the writer lock.
        std::lock_guard<detail::StaticReaderWriterLock> lock(detail::g_RoLock);

        detail::g_pManualLoadList->push_back(*pModule);

        pModule->Relocation((flag & BindFlag_Lazy) == BindFlag_Lazy);
        pModule->CallInit();

        // Let every auto-loaded module bind its variables against the new module.
        for (detail::RoModuleList::iterator it = detail::g_pAutoLoadList->begin(); it != detail::g_pAutoLoadList->end(); it++)
        {
            it->BindVariables(pModule);
        }

        // Same for previously loaded manual modules; the loop stops before the last element,
        // which is pModule itself (just appended above).
        for (detail::RoModuleList::iterator it = detail::g_pManualLoadList->begin(); std::next(it, 1) != detail::g_pManualLoadList->end(); it++)
        {
            it->BindVariables(pModule);
        }
    }

    pOutModule->_fileAddress = fileAddress;
    pOutModule->_bufferAddress = bufferAddress;
    pOutModule->_state = Module::State_Loaded;
    return ResultSuccess();
}

// Searches all loaded modules for a non-local symbol called `name` and stores its absolute
// address (module base + symbol value) in *pOutAddress.
// Auto-loaded modules are searched first (without taking g_RoLock), then manually loaded
// modules under a shared lock. Returns ResultNotFound if no module defines the symbol.
Result  LookupSymbol(uintptr_t* pOutAddress, const char* name) NN_NOEXCEPT
{
    NN_SDK_ASSERT_NOT_NULL(pOutAddress);
    NN_SDK_ASSERT_NOT_NULL(name);

    // Scans one module list in order; on the first non-local hit, writes the address and
    // reports success.
    auto findIn = [pOutAddress, name](const detail::RoModuleList& modules) -> bool
    {
        for (detail::RoModuleList::const_iterator it = modules.cbegin(); it != modules.cend(); ++it)
        {
            detail::Elf::Sym* pSym = it->Lookup(name);
            if (pSym != nullptr && pSym->GetBind() != detail::Elf::STB_LOCAL)
            {
                *pOutAddress = it->GetBase() + pSym->GetValue();
                return true;
            }
        }
        return false;
    };

    if (findIn(*detail::g_pAutoLoadList))
    {
        return ResultSuccess();
    }

    std::shared_lock<detail::StaticReaderWriterLock> readerLock(detail::g_RoLock);
    if (findIn(*detail::g_pManualLoadList))
    {
        return ResultSuccess();
    }
    return ResultNotFound();
}

// Looks up `name` in one specific loaded module and stores its absolute address
// (module base + symbol value) in *pOutAddress. Unlike LookupSymbol, the symbol's binding
// is not checked here. Returns ResultNotFound if the module has no such symbol.
Result  LookupModuleSymbol(uintptr_t* pOutAddress, const Module* pModule, const char* name) NN_NOEXCEPT
{
    NN_SDK_ASSERT_NOT_NULL(pOutAddress);
    NN_SDK_ASSERT_NOT_NULL(pModule);
    NN_SDK_ASSERT_NOT_NULL(name);
    NN_SDK_ASSERT_EQUAL(pModule->_state, Module::State_Loaded);

    auto pRoModule = pModule->_module;
    if (detail::Elf::Sym* pSym = pRoModule->Lookup(name))
    {
        *pOutAddress = pRoModule->GetBase() + pSym->GetValue();
        return ResultSuccess();
    }

    return ResultNotFound();
}

// Unloads a module previously loaded with LoadModule: runs its finalizers, unlinks it,
// reverts references other modules hold to it, and asks the RO service to unmap it.
// Aborts if thread_local destructors are still pending for the module (Clang builds) or if
// unmapping fails. On return pModule is in State_Unloaded.
void    UnloadModule(Module* pModule) NN_NOEXCEPT
{
    NN_SDK_ASSERT_NOT_NULL(pModule);
    NN_SDK_ASSERT_EQUAL(pModule->_state, Module::State_Loaded);

    // Check that the module being unloaded has no thread_local objects whose destructors
    // have not been called yet.
#if defined(NN_BUILD_CONFIG_COMPILER_CLANG)
    const detail::NroHeader* pNroHeader = reinterpret_cast<const detail::NroHeader*>(pModule->_module->GetBase());
    void* dsoHandle = reinterpret_cast<void*>(pModule->_module->GetBase() + pNroHeader->GetDsoHandleOffset());
    if (dsoHandle != 0)
    {
        int64_t destructorRemainingNumber = __nnmusl_get_number_uncalled_tls_dtors(dsoHandle);
        NN_ABORT_UNLESS_EQUAL(destructorRemainingNumber, 0);
    }
#endif
    {
        // Finalize and unlink the module under the writer lock.
        std::lock_guard<detail::StaticReaderWriterLock> lock(detail::g_RoLock);

        pModule->_module->CallFini();

        detail::g_pManualLoadList->erase(detail::g_pManualLoadList->iterator_to(*(pModule->_module)));

        // Revert references from auto-loaded modules, skipping the one whose soname is
        // "nnSdk.nss" (NOTE(review): reason not visible here — confirm).
        for (detail::RoModuleList::iterator it = detail::g_pAutoLoadList->begin(); it != detail::g_pAutoLoadList->end(); it++)
        {
            detail::RoModule* module = &(*it);
            if( module->GetSoname() != nullptr )
            {
                if( std::strcmp("nnSdk.nss", module->GetSoname()) == 0 )
                {
                    continue;
                }
            }
            module->RevertReference(pModule->_module);
        }
        // Revert references from the remaining manually loaded modules.
        for (detail::RoModuleList::iterator it = detail::g_pManualLoadList->begin(); it != detail::g_pManualLoadList->end(); it++)
        {
            detail::RoModule* module = &(*it);
            module->RevertReference(pModule->_module);
        }
    }

    // Unmap the module, bracketed by notification-only svc::Break calls that mirror the ones
    // in LoadModule (Pre passes the mapped base, Post passes the original file address).
    uintptr_t fileAddress = pModule->_fileAddress;
    uintptr_t baseAddress = pModule->_module->GetBase();
    nn::svc::Break(static_cast<nn::svc::BreakReason>(nn::svc::BreakReason_PreUnloadDll|nn::svc::BreakReason_NotificationOnlyFlag), baseAddress, sizeof(nn::ro::detail::NroHeader));
    auto result = g_RefRoInterface->UnmapManualLoadModuleMemory(0, pModule->_module->GetBase());
    nn::svc::Break(static_cast<nn::svc::BreakReason>(nn::svc::BreakReason_PostUnloadDll|nn::svc::BreakReason_NotificationOnlyFlag), fileAddress, sizeof(nn::ro::detail::NroHeader));
    NN_ABORT_UNLESS_RESULT_SUCCESS(result);

    pModule->_state = Module::State_Unloaded;
}

// Runs the calling thread's pending thread_local destructors that belong to pModule.
// Only implemented for Clang builds (via the musl hook); elsewhere it does nothing.
void InvokeTlsDestructorOfCurrentThread(Module* pModule) NN_NOEXCEPT
{
#if defined(NN_BUILD_CONFIG_COMPILER_CLANG)
    const auto moduleBase = pModule->_module->GetBase();
    const detail::NroHeader* pNroHeader = reinterpret_cast<const detail::NroHeader*>(moduleBase);
    // The module's DSO handle sits at a fixed offset from its base, as recorded in the NRO header.
    void* const dsoHandle = reinterpret_cast<void*>(moduleBase + pNroHeader->GetDsoHandleOffset());
    if (dsoHandle != nullptr)
    {
        __nnmusl_call_tls_dtors_for_module(dsoHandle);
    }
#endif
}

}} // namespace nn::ro

