﻿/*--------------------------------------------------------------------------------*
  Copyright (C)Nintendo All rights reserved.

  These coded instructions, statements, and computer programs contain proprietary
  information of Nintendo and/or its licensed developers and are protected by
  national and international copyright laws. They may not be disclosed to third
  parties or copied or duplicated in any form, in whole or in part, without the
  prior written consent of Nintendo.

  The content herein is highly confidential and should be handled accordingly.
 *--------------------------------------------------------------------------------*/

#include <nn/ro.h>
#include <nn/ro/ro_Types.h>

#include "util_Common.h"

#include <nn/nn_Log.h>


namespace {
    // Paths of the two test modules handed to TestNro::SetUp by the fixture below.
    const char* const BinCrossReference1 = "nro/BinCrossReference1.nro";
    const char* const BinCrossReference2 = "nro/BinCrossReference2.nro";

    // Signature used to call every exported test function looked up by the tests:
    // no arguments, returns int.
    typedef int (*ReturnInt)();

    // Itanium-ABI mangled names of exported functions (e.g. "_Z16ExternalFunctionv" is
    // `ExternalFunction()`), resolved via nn::ro::LookupSymbol / TestNro::FindSymbol.
    const char* const ExternalFunctionName = "_Z16ExternalFunctionv";
    const char* const ExternalFunction1Name = "_Z17ExternalFunction1v";
    const char* const ExternalFunction2Name = "_Z17ExternalFunction2v";
    const char* const CallExternalFunctionFrom1Name = "_Z25CallExternalFunctionFrom1v";
    const char* const CallExternalFunctionFrom2Name = "_Z25CallExternalFunctionFrom2v";
    const char* const CallExternalFunction1Name = "_Z21CallExternalFunction1v";
    const char* const CallExternalFunction2Name = "_Z21CallExternalFunction2v";
    const char* const AccessExternValue1FunctionName = "_Z20AccessExternalValue1v";
    const char* const AccessExternValue2FunctionName = "_Z20AccessExternalValue2v";
    const char* const AccessExternalValueFrom1FunctionName = "_Z24AccessExternalValueFrom1v";
    const char* const AccessExternalValueFrom2FunctionName = "_Z24AccessExternalValueFrom2v";

    // Unmangled names of exported global int variables.
    const char* const ExternalValueName = "g_ExternValue";
    const char* const ExternalValue1Name = "g_ExternValue1";
    const char* const ExternalValue2Name = "g_ExternValue2";

    // Fixture parameterized by the nn::ro::BindFlag passed to TestNro::Load.
    // SetUp gives each TestNro its module path and the shared test allocator, then
    // initializes the ro library; TearDown finalizes it. Loading and unloading the
    // modules is left to each test.
    class CrossReferenceTest : public ::testing::TestWithParam<nn::ro::BindFlag>
    {
    protected:

        void SetUp() override
        {
            m_Allocator = &TestAllocator::GetInstance();
            m_Nro1.SetUp(BinCrossReference1, m_Allocator->GetAllocator());
            m_Nro2.SetUp(BinCrossReference2, m_Allocator->GetAllocator());
            nn::ro::Initialize();
        }

        void TearDown() override
        {
            nn::ro::Finalize();
        }

        // Non-owning; points at the TestAllocator singleton once SetUp has run.
        TestAllocator* m_Allocator = nullptr;
        TestNro m_Nro1;
        TestNro m_Nro2;
    };
} // namespace

// Instantiate every CrossReferenceTest case once, with nn::ro::BindFlag_Lazy as the
// only BindFlag parameter.
INSTANTIATE_TEST_CASE_P(
    ManualDll,
    CrossReferenceTest,
    ::testing::Values(nn::ro::BindFlag_Lazy));

// Two modules that reference each other's symbols can both be loaded and bound.
//
// Both m_Nro1 and m_Nro2 export ExternalFunction and g_ExternValue (duplicate symbols).
// The global lookup (nn::ro::LookupSymbol) must pick the definition of whichever module
// was loaded first. The test checks this once with load order (m_Nro1, m_Nro2). It then
// reloads m_Nro1 so the order becomes (m_Nro2, m_Nro1) and checks again.
//
// NOTE: a failed ASSERT_* returns from this function early, which skips the Unload
// calls at the end.
TEST_P(CrossReferenceTest, Bind)
{
    auto result = m_Nro1.Load(GetParam());
    ASSERT_RESULT_SUCCESS(result);

    result = m_Nro2.Load(GetParam());
    ASSERT_RESULT_SUCCESS(result);

    uintptr_t addr = 0;
    uintptr_t tmpAddr = 0;
    ReturnInt func = nullptr;
    int* pValue = nullptr;

    // Pass 1 (load order m_Nro1, m_Nro2): duplicate symbols are resolved by load order,
    // so the global lookup yields m_Nro1's definitions.
    {
        // Function
        result = nn::ro::LookupSymbol(&addr, ExternalFunctionName);
        ASSERT_RESULT_SUCCESS(result);

        func = reinterpret_cast<ReturnInt>(addr);
        ASSERT_EQ(func(), 1);

        result = m_Nro1.FindSymbol(&tmpAddr, ExternalFunctionName);
        ASSERT_RESULT_SUCCESS(result);
        ASSERT_EQ(addr, tmpAddr);

        result = m_Nro2.FindSymbol(&tmpAddr, ExternalFunctionName);
        ASSERT_RESULT_SUCCESS(result);
        ASSERT_NE(addr, tmpAddr);

        // Expected values are ExternalFunction's result (1) plus a per-caller offset.
        // Presumably From1 is defined in m_Nro1 and From2 in m_Nro2; confirm against
        // the module sources.
        result = nn::ro::LookupSymbol(&addr, CallExternalFunctionFrom1Name);
        ASSERT_RESULT_SUCCESS(result);

        func = reinterpret_cast<ReturnInt>(addr);
        ASSERT_EQ(func(), 1 + 1);

        result = nn::ro::LookupSymbol(&addr, CallExternalFunctionFrom2Name);
        ASSERT_RESULT_SUCCESS(result);

        func = reinterpret_cast<ReturnInt>(addr);
        ASSERT_EQ(func(), 1 + 2);


        // Variable
        result = nn::ro::LookupSymbol(&addr, ExternalValueName);
        ASSERT_RESULT_SUCCESS(result);

        pValue = reinterpret_cast<int*>(addr);
        ASSERT_EQ(*pValue, 1);

        result = m_Nro1.FindSymbol(&tmpAddr, ExternalValueName);
        ASSERT_RESULT_SUCCESS(result);
        ASSERT_EQ(addr, tmpAddr);

        result = m_Nro2.FindSymbol(&tmpAddr, ExternalValueName);
        ASSERT_RESULT_SUCCESS(result);
        ASSERT_NE(addr, tmpAddr);

        result = nn::ro::LookupSymbol(&addr, AccessExternalValueFrom1FunctionName);
        ASSERT_RESULT_SUCCESS(result);

        func = reinterpret_cast<ReturnInt>(addr);
        ASSERT_EQ(func(), 1 + 1);

        result = nn::ro::LookupSymbol(&addr, AccessExternalValueFrom2FunctionName);
        ASSERT_RESULT_SUCCESS(result);

        func = reinterpret_cast<ReturnInt>(addr);
        ASSERT_EQ(func(), 1 + 2);
    }

    // References are resolved by the module that was loaded first.
    {
        // Function
        result = nn::ro::LookupSymbol(&addr, ExternalFunction1Name);
        ASSERT_RESULT_SUCCESS(result);

        func = reinterpret_cast<ReturnInt>(addr);
        ASSERT_EQ(func(), 1);

        result = nn::ro::LookupSymbol(&addr, CallExternalFunction1Name);
        ASSERT_RESULT_SUCCESS(result);

        func = reinterpret_cast<ReturnInt>(addr);
        ASSERT_EQ(func(), 1 + 1);

        // Variable
        result = nn::ro::LookupSymbol(&addr, ExternalValue1Name);
        ASSERT_RESULT_SUCCESS(result);

        pValue = reinterpret_cast<int*>(addr);
        ASSERT_EQ(*pValue, 1);

        result = nn::ro::LookupSymbol(&addr, AccessExternValue1FunctionName);
        ASSERT_RESULT_SUCCESS(result);

        func = reinterpret_cast<ReturnInt>(addr);
        ASSERT_EQ(func(), 1 + 1);
    }

    // References are resolved by the module that was loaded later.
    {
        // Function
        result = nn::ro::LookupSymbol(&addr, ExternalFunction2Name);
        ASSERT_RESULT_SUCCESS(result);

        func = reinterpret_cast<ReturnInt>(addr);
        ASSERT_EQ(func(), 2);

        result = nn::ro::LookupSymbol(&addr, CallExternalFunction2Name);
        ASSERT_RESULT_SUCCESS(result);

        func = reinterpret_cast<ReturnInt>(addr);
        ASSERT_EQ(func(), 2 + 2);

        // Variable
        result = nn::ro::LookupSymbol(&addr, ExternalValue2Name);
        ASSERT_RESULT_SUCCESS(result);

        pValue = reinterpret_cast<int*>(addr);
        ASSERT_EQ(*pValue, 2);

        result = nn::ro::LookupSymbol(&addr, AccessExternValue2FunctionName);
        ASSERT_RESULT_SUCCESS(result);

        func = reinterpret_cast<ReturnInt>(addr);
        ASSERT_EQ(func(), 2 + 2);
    }

    // Unloading the referencing module and loading it again causes its references to be
    // re-resolved. Afterwards the load order is (m_Nro2, m_Nro1).
    m_Nro1.Unload();
    // The reload must succeed; otherwise every lookup below would test a stale or
    // missing module.
    result = m_Nro1.Load(GetParam());
    ASSERT_RESULT_SUCCESS(result);

    // Pass 2 (load order m_Nro2, m_Nro1): duplicate symbols now resolve to m_Nro2's
    // definitions.
    {
        // Function
        result = nn::ro::LookupSymbol(&addr, ExternalFunctionName);
        ASSERT_RESULT_SUCCESS(result);

        func = reinterpret_cast<ReturnInt>(addr);
        ASSERT_EQ(func(), 2);

        result = m_Nro1.FindSymbol(&tmpAddr, ExternalFunctionName);
        ASSERT_RESULT_SUCCESS(result);
        ASSERT_NE(addr, tmpAddr);

        result = m_Nro2.FindSymbol(&tmpAddr, ExternalFunctionName);
        ASSERT_RESULT_SUCCESS(result);
        ASSERT_EQ(addr, tmpAddr);

        result = nn::ro::LookupSymbol(&addr, CallExternalFunctionFrom1Name);
        ASSERT_RESULT_SUCCESS(result);

        func = reinterpret_cast<ReturnInt>(addr);
        ASSERT_EQ(func(), 2 + 1);

        result = nn::ro::LookupSymbol(&addr, CallExternalFunctionFrom2Name);
        ASSERT_RESULT_SUCCESS(result);

        func = reinterpret_cast<ReturnInt>(addr);
        ASSERT_EQ(func(), 2 + 2);


        // Variable
        result = nn::ro::LookupSymbol(&addr, ExternalValueName);
        ASSERT_RESULT_SUCCESS(result);

        pValue = reinterpret_cast<int*>(addr);
        ASSERT_EQ(*pValue, 2);

        result = m_Nro1.FindSymbol(&tmpAddr, ExternalValueName);
        ASSERT_RESULT_SUCCESS(result);
        ASSERT_NE(addr, tmpAddr);

        result = m_Nro2.FindSymbol(&tmpAddr, ExternalValueName);
        ASSERT_RESULT_SUCCESS(result);
        ASSERT_EQ(addr, tmpAddr);

        result = nn::ro::LookupSymbol(&addr, AccessExternalValueFrom1FunctionName);
        ASSERT_RESULT_SUCCESS(result);

        func = reinterpret_cast<ReturnInt>(addr);
        ASSERT_EQ(func(), 2 + 1);

        result = nn::ro::LookupSymbol(&addr, AccessExternalValueFrom2FunctionName);
        ASSERT_RESULT_SUCCESS(result);

        func = reinterpret_cast<ReturnInt>(addr);
        ASSERT_EQ(func(), 2 + 2);
    }

    // Non-duplicated "1" symbols: same expectations as in pass 1.
    // NOTE(review): originally labelled "resolved by the earlier-loaded module", but
    // after the reload m_Nro1 is the later-loaded module.
    {
        // Function
        result = nn::ro::LookupSymbol(&addr, ExternalFunction1Name);
        ASSERT_RESULT_SUCCESS(result);

        func = reinterpret_cast<ReturnInt>(addr);
        ASSERT_EQ(func(), 1);

        result = nn::ro::LookupSymbol(&addr, CallExternalFunction1Name);
        ASSERT_RESULT_SUCCESS(result);

        func = reinterpret_cast<ReturnInt>(addr);
        ASSERT_EQ(func(), 1 + 1);

        // Variable
        result = nn::ro::LookupSymbol(&addr, ExternalValue1Name);
        ASSERT_RESULT_SUCCESS(result);

        pValue = reinterpret_cast<int*>(addr);
        ASSERT_EQ(*pValue, 1);

        result = nn::ro::LookupSymbol(&addr, AccessExternValue1FunctionName);
        ASSERT_RESULT_SUCCESS(result);

        func = reinterpret_cast<ReturnInt>(addr);
        ASSERT_EQ(func(), 1 + 1);
    }

    // Non-duplicated "2" symbols: same expectations as in pass 1.
    // NOTE(review): originally labelled "resolved by the later-loaded module", but
    // after the reload m_Nro2 is the earlier-loaded module.
    {
        // Function
        result = nn::ro::LookupSymbol(&addr, ExternalFunction2Name);
        ASSERT_RESULT_SUCCESS(result);

        func = reinterpret_cast<ReturnInt>(addr);
        ASSERT_EQ(func(), 2);

        result = nn::ro::LookupSymbol(&addr, CallExternalFunction2Name);
        ASSERT_RESULT_SUCCESS(result);

        func = reinterpret_cast<ReturnInt>(addr);
        ASSERT_EQ(func(), 2 + 2);

        // Variable
        result = nn::ro::LookupSymbol(&addr, ExternalValue2Name);
        ASSERT_RESULT_SUCCESS(result);

        pValue = reinterpret_cast<int*>(addr);
        ASSERT_EQ(*pValue, 2);

        result = nn::ro::LookupSymbol(&addr, AccessExternValue2FunctionName);
        ASSERT_RESULT_SUCCESS(result);

        func = reinterpret_cast<ReturnInt>(addr);
        ASSERT_EQ(func(), 2 + 2);
    }

    m_Nro2.Unload();
    m_Nro1.Unload();
} // NOLINT(impl/function_size)

