﻿/*--------------------------------------------------------------------------------*
  Copyright (C)Nintendo All rights reserved.

  These coded instructions, statements, and computer programs contain proprietary
  information of Nintendo and/or its licensed developers and are protected by
  national and international copyright laws. They may not be disclosed to third
  parties or copied or duplicated in any form, in whole or in part, without the
  prior written consent of Nintendo.

  The content herein is highly confidential and should be handled accordingly.
 *--------------------------------------------------------------------------------*/

#include <nn/ro.h>
#include <nn/ro/ro_Types.h>
#include <nn/mem/mem_StandardAllocator.h>

#include <nn/sf/sf_HipcClient.h>
#include <nn/sf/sf_ExpHeapAllocator.h>

#include <nn/ro/detail/ro_RoInterface.h>
#include <nn/ro/detail/ro_PortName.h>
#include <nn/ro/ro_Result.h>
#include <nn/ldr/ldr_Result.h>
#include <nn/ro/detail/ro_RoModule.h>

#include <nn/os/os_NativeHandle.h>

#include <nn/spl/spl_Api.h>
#include <nn/settings/fwdbg/settings_SettingsGetterApi.h>
#include <nn/settings/fwdbg/settings_SettingsSetterApi.h>

#include "util_Common.h"
#include "bin_Definitions.h"

namespace {
    // NRO images exercised by these tests; the paths are passed to TestNro::SetUp / TestFsManager::ReadAll.
    // Declared `const char* const` so the pointers themselves cannot be reassigned.
    const char* const BinVersionScriptGlobal1Name = "nro/BinVersionScriptGlobal1.nro";
    const char* const BinVersionScriptGlobal2Name = "nro/BinVersionScriptGlobal2.nro";
    const char* const BinAccessExternName = "nro/BinAccessExtern.nro";
    // An nrr file outside the .nrr directory, i.e. one without a valid signature (see InvalidNrrTest).
    const char* const InvalidNrrName = "InvalidNrr.nrr";

    // Signatures of the functions exported by the test NROs; symbol addresses are cast to these.
    using ReturnInt = int (*)();
    using ReturnClass = TestClassInt& (*)();
    using ReturnPtr = int* (*)();

    // Itanium-mangled names of GetGlobalInt(), GetGlobalClass() and GetGlobalPtr().
    const char* const GetGlobalIntName = "_Z12GetGlobalIntv";
    const char* const GetGlobalClassName = "_Z14GetGlobalClassv";
    const char* const GetGlobalPtrName = "_Z12GetGlobalPtrv";
    // Global-namespace variables are not mangled, so the variable is looked up by its plain name.
    const char* const GlobalIntName = "g_GlobalInt";

    // Number of load/unload rounds performed by LoadTwice.
    const int TestNum = 100;



    // Empty tag type used only as the second template argument of MyAllocator.
    // NOTE(review): presumably it distinguishes this static heap from other ExpHeapStaticAllocator instances.
    struct AllocatorTag
    {
    };

    // Allocator whose Policy is used for the sf HIPC proxy created in MakeInterface.
    // NOTE(review): 1024 is presumably the static heap size in bytes — confirm.
    using MyAllocator = nn::sf::ExpHeapStaticAllocator<1024, AllocatorTag>;

    // A namespace-scope object whose constructor runs MyAllocator::Initialize during static
    // initialization of this translation unit.
    struct AllocatorInitializer
    {
        AllocatorInitializer()
        {
            MyAllocator::Initialize(nn::lmem::CreationOption_NoOption);
        }
    };

    AllocatorInitializer g_AllocatorInitializer;


    // Value-parameterized fixture: each TEST_P runs once for every nn::ro::BindFlag it is instantiated with.
    // SetUp prepares (but does not load) the TestNro objects used by the tests.
    class LoadSequenceTest : public ::testing::TestWithParam<nn::ro::BindFlag>
    {
    protected:
        void SetUp() override
        {
            m_Allocator = &TestAllocator::GetInstance();
            m_NroGlobal1.SetUp(BinVersionScriptGlobal1Name, m_Allocator->GetAllocator());
            // Intentionally the same image as m_NroGlobal1: AlreadyLoaded loads it while m_NroGlobal1 is
            // loaded and expects nn::ro::ResultNroAlreadyLoaded.
            m_NroGlobal2.SetUp(BinVersionScriptGlobal1Name, m_Allocator->GetAllocator());
            // NOTE(review): m_NroGlobal3 is not used by any test visible in this file.
            m_NroGlobal3.SetUp(BinVersionScriptGlobal2Name, m_Allocator->GetAllocator());
            m_NroNotDefined.SetUp(BinAccessExternName, m_Allocator->GetAllocator());
        }

        void TearDown() override
        {
        }

        // Process-wide test allocator (singleton); set in SetUp.
        TestAllocator* m_Allocator = nullptr;
        TestNro m_NroGlobal1;
        TestNro m_NroGlobal2;
        TestNro m_NroGlobal3;
        TestNro m_NroNotDefined;
    };

    // Scope guard: calls nn::ro::Initialize() on construction and nn::ro::Finalize() on destruction,
    // so Finalize() runs on every exit path of the enclosing scope, including early returns from ASSERT_*.
    class AutoInitRo
    {
    public:
        AutoInitRo()
        {
            nn::ro::Initialize();
        }
        ~AutoInitRo()
        {
            nn::ro::Finalize();
        }

        // A copy would call nn::ro::Finalize() a second time for a single Initialize(); forbid copying.
        AutoInitRo(const AutoInitRo&) = delete;
        AutoInitRo& operator=(const AutoInitRo&) = delete;
    };

    void CheckSuccessLoaded(const TestNro& nro)
    {
        uintptr_t addr;

        /*
           1. 関数の実行結果がGlobal 関数の結果になっている
           2. LookupSymbol の結果が Global 関数を指している
           3. LookupModuleSymbol でそのモジュールのGlobal シンボルにアクセスできる
         */

        // 1
        auto result = nn::ro::LookupSymbol(&addr, GetGlobalIntName);
        ASSERT_RESULT_SUCCESS(result);
        ReturnInt getGlobalInt = reinterpret_cast<ReturnInt>(addr);
        ASSERT_EQ(getGlobalInt(), 1);

        // 2, 3
        result = nro.FindSymbol(&addr, GetGlobalIntName);
        ASSERT_RESULT_SUCCESS(result);
        ASSERT_EQ(addr, reinterpret_cast<uintptr_t>(getGlobalInt));

        // 1
        result = nn::ro::LookupSymbol(&addr, GetGlobalClassName);
        ASSERT_RESULT_SUCCESS(result);
        ReturnClass getGlobalClass = reinterpret_cast<ReturnClass>(addr);
        TestClassInt& testClass = getGlobalClass();
        ASSERT_EQ(testClass.Get(), 1);
        ASSERT_EQ(testClass.Calc(), 2);

        // 2, 3
        result = nro.FindSymbol(&addr, GetGlobalClassName);
        ASSERT_RESULT_SUCCESS(result);
        ASSERT_EQ(addr, reinterpret_cast<uintptr_t>(getGlobalClass));

        // 1
        result = nn::ro::LookupSymbol(&addr, GetGlobalPtrName);
        ASSERT_RESULT_SUCCESS(result);
        ReturnPtr getGlobalPtr = reinterpret_cast<ReturnPtr>(addr);

        result = nn::ro::LookupSymbol(&addr, GlobalIntName);
        ASSERT_RESULT_SUCCESS(result);
        ASSERT_EQ(addr, reinterpret_cast<uintptr_t>(getGlobalPtr()));

        // 2, 3
        result = nro.FindSymbol(&addr, GetGlobalPtrName);
        ASSERT_RESULT_SUCCESS(result);
        ASSERT_EQ(addr, reinterpret_cast<uintptr_t>(getGlobalPtr));
    }

    // Opens a new HIPC session to the ro service on port nn::ro::detail::PortNameForRo and returns the proxy,
    // allocated through MyAllocator::Policy.
    // The result is checked with NN_ABORT_UNLESS rather than NN_ASSERT_RESULT_SUCCESS: an NN_ASSERT-family
    // check may be compiled out in builds with asserts disabled (confirm for Release), and the caller
    // dereferences the returned proxy unconditionally.
    nn::sf::SharedPointer<nn::ro::detail::IRoInterface> MakeInterface()
    {
        nn::sf::SharedPointer<nn::ro::detail::IRoInterface> proxy;
        auto result = nn::sf::CreateHipcProxyByName<nn::ro::detail::IRoInterface, MyAllocator::Policy>(
                &proxy, nn::ro::detail::PortNameForRo);
        NN_ABORT_UNLESS(result.IsSuccess());

        return proxy;
    }

    // Maps a result from the raw IRoInterface (presumably an nn::ldr result) to the matching nn::ro result.
    // `r <= ResultX()` is the nn::Result category test ("r belongs to ResultX"). Categories are probed in a
    // fixed order and the first match wins. A result outside every listed category is returned unchanged.
    nn::Result ConvertResult(nn::Result ldrResult) NN_NOEXCEPT
    {
        if (ldrResult <= nn::ldr::ResultOutOfAddressSpace())
        {
            return nn::ro::ResultOutOfAddressSpace();
        }
        if (ldrResult <= nn::ldr::ResultInvalidNroImage())
        {
            return nn::ro::ResultInvalidNroImage();
        }
        if (ldrResult <= nn::ldr::ResultInvalidNrrImage())
        {
            return nn::ro::ResultInvalidNrrImage();
        }
        if (ldrResult <= nn::ldr::ResultNotAuthorized())
        {
            return nn::ro::ResultNotAuthorized();
        }
        if (ldrResult <= nn::ldr::ResultMaxModule())
        {
            return nn::ro::ResultMaxModule();
        }
        if (ldrResult <= nn::ldr::ResultMaxRegistration())
        {
            return nn::ro::ResultMaxRegistration();
        }
        if (ldrResult <= nn::ldr::ResultNroAlreadyLoaded())
        {
            return nn::ro::ResultNroAlreadyLoaded();
        }

        // None of the known ldr categories matched: pass the result through untouched.
        return ldrResult;
    }

} // namespace

// Run every LoadSequenceTest case twice: once with nn::ro::BindFlag_Lazy and once with nn::ro::BindFlag_Now.
INSTANTIATE_TEST_CASE_P(
    ManualDll,
    LoadSequenceTest,
    ::testing::Values(nn::ro::BindFlag_Lazy, nn::ro::BindFlag_Now));

// A DLL that has been loaded and unloaded can be loaded again (checked TestNum times).
TEST_P(LoadSequenceTest, LoadTwice)
{
    AutoInitRo ro;
    for (int i = 0; i < TestNum; i++)
    {
        auto result = m_NroGlobal1.Load(GetParam());
        ASSERT_RESULT_SUCCESS(result);

        // CheckSuccessLoaded's ASSERT_* only return from the helper. ASSERT_NO_FATAL_FAILURE stops the
        // test on the first failure instead of repeating it for every remaining iteration.
        ASSERT_NO_FATAL_FAILURE(CheckSuccessLoaded(m_NroGlobal1));

        m_NroGlobal1.Unload();

        result = m_NroGlobal1.Load(GetParam());
        ASSERT_RESULT_SUCCESS(result);

        ASSERT_NO_FATAL_FAILURE(CheckSuccessLoaded(m_NroGlobal1));

        m_NroGlobal1.Unload();
    }
}

// A DLL that is already loaded cannot be loaded again.
TEST_P(LoadSequenceTest, AlreadyLoaded)
{
    AutoInitRo ro;
    auto result = m_NroGlobal1.Load(GetParam());
    ASSERT_RESULT_SUCCESS(result);

    // Stop the test if the helper hit a fatal failure (its ASSERT_* only return from the helper).
    ASSERT_NO_FATAL_FAILURE(CheckSuccessLoaded(m_NroGlobal1));

    // m_NroGlobal2 wraps the same image as m_NroGlobal1, so this load must be rejected.
    result = m_NroGlobal2.Load(GetParam());
    ASSERT_RESULT_FAILURE_VALUE(result, nn::ro::ResultNroAlreadyLoaded());

    // The rejected load must not disturb the module that is already loaded.
    ASSERT_NO_FATAL_FAILURE(CheckSuccessLoaded(m_NroGlobal1));

    m_NroGlobal1.Unload();
    // Also unload m_NroGlobal2 (whose load failed) so that its Nrr file is released.
    m_NroGlobal2.Unload();
}

// An NRO cannot be loaded while its NRR is not registered, and loads once the NRR is registered.
TEST_P(LoadSequenceTest, NrrTest)
{
    AutoInitRo ro;
    TestFsManager& fs = TestFsManager::GetInstance();
    nn::mem::StandardAllocator* allocator = m_Allocator->GetAllocator();

    void* pNro = allocator->Allocate(MaxFileSize, DefaultAlign);
    ASSERT_TRUE(pNro != nullptr);
    fs.ReadAll(pNro, MaxFileSize, BinVersionScriptGlobal1Name);

    void* pBuffer = nullptr;
    size_t bufferSize = 0;

    // gtest-style check (as used elsewhere in this test): always evaluated, reported as a test failure.
    auto result = nn::ro::GetBufferSize(&bufferSize, pNro);
    ASSERT_RESULT_SUCCESS(result);

    if (bufferSize != 0)
    {
        pBuffer = allocator->Allocate(bufferSize, DefaultAlign);
        ASSERT_TRUE(pBuffer != nullptr);
    }

    // A DLL whose nrr has not been registered cannot be loaded.
    nn::ro::Module module;
    result = nn::ro::LoadModule(&module, pNro, pBuffer, bufferSize, GetParam());
    ASSERT_RESULT_FAILURE_VALUE(result, nn::ro::ResultNotAuthorized());

    void* pNrr = allocator->Allocate(MaxFileSize, DefaultAlign);
    ASSERT_TRUE(pNrr != nullptr);

    fs.ReadAll(pNrr, MaxFileSize, NrrName);

    nn::ro::RegistrationInfo info;
    result = nn::ro::RegisterModuleInfo(&info, pNrr);
    ASSERT_RESULT_SUCCESS(result);

    // With the nrr registered, the same image now loads.
    result = nn::ro::LoadModule(&module, pNro, pBuffer, bufferSize, GetParam());
    ASSERT_RESULT_SUCCESS(result);

    nn::ro::UnloadModule(&module);
    nn::ro::UnregisterModuleInfo(&info);

    // Return the buffers to the shared TestAllocator only after the module is unloaded and the nrr
    // unregistered, since ro may reference them until then.
    allocator->Free(pNrr);
    if (pBuffer != nullptr)
    {
        allocator->Free(pBuffer);
    }
    allocator->Free(pNro);
}

TEST_P(LoadSequenceTest, InvalidNrrTest)
{
    // An nrr file outside the .nrr directory has no signature, so normally it must not be registrable.
    AutoInitRo ro;
    TestFsManager& fs = TestFsManager::GetInstance();
    nn::mem::StandardAllocator* allocator = m_Allocator->GetAllocator();

    // Tests that the fwdbg setting ro.ease_nro_restriction behaves as intended when it is true / false.
    // When ConfigItem_IsDevelopmentFunctionEnabled is false, nrr signature-verification errors stay enabled.
    nn::spl::Initialize();
    const bool isDevelopmentFunctionEnabled = nn::spl::GetConfigBool(nn::spl::ConfigItem_IsDevelopmentFunctionEnabled);
    nn::spl::Finalize();

    bool currentEaseNroRestriction;
    const size_t readEaseNroRestrictionSize = nn::settings::fwdbg::GetSettingsItemValue(&currentEaseNroRestriction, sizeof(currentEaseNroRestriction), "ro", "ease_nro_restriction");
    ASSERT_EQ(sizeof(currentEaseNroRestriction), readEaseNroRestrictionSize);

    // Allocated after the early ASSERT above so that a failure there cannot leak the buffer.
    void* pNrr = allocator->Allocate(MaxFileSize, DefaultAlign);
    ASSERT_TRUE(pNrr != nullptr);

    fs.ReadAll(pNrr, MaxFileSize, InvalidNrrName);

    nn::ro::RegistrationInfo info;

    // (1) With ease_nro_restriction == true, the signature-verification error must be suppressed.
    bool easeNroRestrictionForTest = true;
    nn::settings::fwdbg::SetSettingsItemValue("ro", "ease_nro_restriction", &easeNroRestrictionForTest, sizeof(easeNroRestrictionForTest));
    auto result = nn::ro::RegisterModuleInfo(&info, pNrr);
    // Restore the original setting immediately after the registration attempt.
    nn::settings::fwdbg::SetSettingsItemValue("ro", "ease_nro_restriction", &currentEaseNroRestriction, sizeof(currentEaseNroRestriction));
    if (result.IsFailure())
    {
        // With development functions disabled this relaxation must not work, so failing is correct then.
        // EXPECT (not ASSERT) so that case (2) still runs and pNrr is still freed.
        EXPECT_FALSE(isDevelopmentFunctionEnabled);
    }
    else
    {
        // Conversely, the relaxation must never take effect while development functions are disabled.
        EXPECT_TRUE(isDevelopmentFunctionEnabled);

        // Check that everything other than the signature is valid by loading an nro.
        result = m_NroGlobal1.Load(GetParam());
        ASSERT_RESULT_SUCCESS(result);

        m_NroGlobal1.Unload();

        nn::ro::UnregisterModuleInfo(&info);
    }

    // (2) With ease_nro_restriction == false, the signature-verification error must occur.
    easeNroRestrictionForTest = false;
    nn::settings::fwdbg::SetSettingsItemValue("ro", "ease_nro_restriction", &easeNroRestrictionForTest, sizeof(easeNroRestrictionForTest));
    result = nn::ro::RegisterModuleInfo(&info, pNrr);
    nn::settings::fwdbg::SetSettingsItemValue("ro", "ease_nro_restriction", &currentEaseNroRestriction, sizeof(currentEaseNroRestriction));
    if (result.IsFailure())
    {
        // Free before asserting so that an unexpected result value does not leak the buffer.
        allocator->Free(pNrr);
        ASSERT_RESULT_FAILURE_VALUE(result, nn::ro::ResultNotAuthorized());
    }
    else
    {
        // Check that everything other than the signature is valid by loading an nro.
        result = m_NroGlobal1.Load(GetParam());
        ASSERT_RESULT_SUCCESS(result);

        m_NroGlobal1.Unload();

        nn::ro::UnregisterModuleInfo(&info);

        allocator->Free(pNrr);

        // Signature verification passed even though ease_nro_restriction is false: always fail this path.
        FAIL() << "InvalidNrr.nrr was registered although ro.ease_nro_restriction is false";
    }
}

// nn::ro can be used again after Finalize() followed by a new Initialize().
TEST_P(LoadSequenceTest, InitializeTest)
{
    TestFsManager& fs = TestFsManager::GetInstance();
    nn::mem::StandardAllocator* allocator = m_Allocator->GetAllocator();

    void* pNrr = allocator->Allocate(MaxFileSize, DefaultAlign);
    ASSERT_TRUE(pNrr != nullptr);
    fs.ReadAll(pNrr, MaxFileSize, NrrName);

    void* pNro = allocator->Allocate(MaxFileSize, DefaultAlign);
    ASSERT_TRUE(pNro != nullptr);
    fs.ReadAll(pNro, MaxFileSize, BinVersionScriptGlobal1Name);

    void* pBuffer = nullptr;
    size_t bufferSize = 0;
    // gtest-style check (as used elsewhere in this test): always evaluated, reported as a test failure.
    auto result = nn::ro::GetBufferSize(&bufferSize, pNro);
    ASSERT_RESULT_SUCCESS(result);

    if (bufferSize != 0)
    {
        pBuffer = allocator->Allocate(bufferSize, DefaultAlign);
        ASSERT_TRUE(pBuffer != nullptr);
    }

    {
        AutoInitRo ro;

        nn::ro::RegistrationInfo info;
        result = nn::ro::RegisterModuleInfo(&info, pNrr);
        ASSERT_RESULT_SUCCESS(result);

        // Neither UnloadModule nor UnregisterModuleInfo is called: this scope relies on Finalize() at scope
        // exit to release them (presumably intentional). The identical register/load below verifies it.
        nn::ro::Module module;
        result = nn::ro::LoadModule(&module, pNro, pBuffer, bufferSize, GetParam());
        ASSERT_RESULT_SUCCESS(result);
    }

    // After Finalize, Initialize can be called and nn::ro used again.
    {
        AutoInitRo ro;

        nn::ro::RegistrationInfo info;
        result = nn::ro::RegisterModuleInfo(&info, pNrr);
        ASSERT_RESULT_SUCCESS(result);

        // Re-read the image before loading it again.
        // NOTE(review): presumably the previous load may leave the buffer contents unusable — confirm.
        fs.ReadAll(pNro, MaxFileSize, BinVersionScriptGlobal1Name);
        nn::ro::Module module;
        result = nn::ro::LoadModule(&module, pNro, pBuffer, bufferSize, GetParam());
        ASSERT_RESULT_SUCCESS(result);
    }

    // Both ro sessions have been finalized (the second load succeeding shows Finalize released the first),
    // so the buffers can be returned to the shared TestAllocator.
    if (pBuffer != nullptr)
    {
        allocator->Free(pBuffer);
    }
    allocator->Free(pNro);
    allocator->Free(pNrr);
}

// A second ro session registered for this process must not be able to unmap a module loaded through
// nn::ro, and the module must still work afterwards.
TEST_P(LoadSequenceTest, SessionTest)
{
    AutoInitRo ro;

    auto result = m_NroGlobal1.Load(GetParam());
    ASSERT_RESULT_SUCCESS(result);

    auto roInterface = MakeInterface();
    nn::ro::Module* pModule = m_NroGlobal1.GetModule();

    auto currentProcess = static_cast<nn::os::NativeHandle>(nn::svc::PSEUDO_HANDLE_CURRENT_PROCESS.value);
    result = roInterface->RegisterProcessHandle(0, nn::sf::NativeHandle(currentProcess, false));
    // The unmap check runs only if the extra session accepted the process registration.
    if (result.IsSuccess())
    {
        // The unmap must fail, and (after ConvertResult maps ldr results to ro results) the failure must
        // not be an nn::ro error result. NN_ABORT_UNLESS is kept rather than ASSERT_*, presumably because
        // continuing with a module that was wrongly unmapped would be unsafe.
        result = ConvertResult(roInterface->UnmapManualLoadModuleMemory(
                    0, pModule->_module->GetBase()));
        NN_ABORT_UNLESS(result.IsFailure());
        NN_ABORT_UNLESS(!nn::ro::ResultRoError::Includes(result));
    }

    // Close the extra session before re-checking the module.
    roInterface.Reset();

    // Stop the test if the helper hit a fatal failure (its ASSERT_* only return from the helper).
    ASSERT_NO_FATAL_FAILURE(CheckSuccessLoaded(m_NroGlobal1));

    m_NroGlobal1.Unload();
}

// Loads nro/BinAccessExtern.nro and checks only that the load succeeds.
// NOTE(review): judging by the test name, the image references symbols that are not defined — confirm.
TEST_P(LoadSequenceTest, NotDefined)
{
    AutoInitRo roSession;

    // The load itself must succeed.
    auto loadResult = m_NroNotDefined.Load(GetParam());
    ASSERT_RESULT_SUCCESS(loadResult);

    // The loaded DLL cannot be accessed, so nothing in it is inspected; just unload it again.
    m_NroNotDefined.Unload();
}
