﻿/*--------------------------------------------------------------------------------*
  Copyright (C)Nintendo All rights reserved.

  These coded instructions, statements, and computer programs contain proprietary
  information of Nintendo and/or its licensed developers and are protected by
  national and international copyright laws. They may not be disclosed to third
  parties or copied or duplicated in any form, in whole or in part, without the
  prior written consent of Nintendo.

  The content herein is highly confidential and should be handled accordingly.
 *--------------------------------------------------------------------------------*/

#include <stack>
#include <nnt/nntest.h>
#include <nn/nn_Assert.h>
#include <nn/nn_Log.h>
#include <nn/os.h>
#include <nn/diag.h>
#include "../../../../Os/Sources/Tests/Common/test_SetjmpWithFpuContext.h"

namespace {
    void PrintBacktraceItemWithSymbol(uintptr_t address)
    {
        // noreturn 関数の場合は、関数末尾の lr 命令の直後に、次の関数の先頭が来るため、
        // アドレスを 1 引いて、次の関数のシンボルが検索されないようにする。
        static char s_NameBuffer[128];
        const auto symbolAddress = nn::diag::GetSymbolName(s_NameBuffer, sizeof(s_NameBuffer), address - 1);
        if (symbolAddress != 0u)
        {
            NN_LOG("   0x%P %s+0x%X\n", address, s_NameBuffer, address - symbolAddress);
        }
        else
        {
            NN_LOG("   0x%P (unknown)\n", address);
        }
    }

    void PrintBacktraceWithSymbol(uintptr_t addresses[], int addressCount)
    {
        NN_LOG("Stack trace:\n");
        for (int i = 0; i < addressCount; i++)
        {
            PrintBacktraceItemWithSymbol(addresses[i]);
        }
    }

#if defined(NN_BUILD_CONFIG_OS_HORIZON)
    void PrintBacktraceItemWithModule(uintptr_t address)
    {
        static char s_PathBuffer[128];
        const auto moduleAddress = nn::diag::GetModulePath(s_PathBuffer, sizeof(s_PathBuffer), address);
        if (moduleAddress != 0u)
        {
            NN_LOG("   0x%P %s+0x%X\n", address, s_PathBuffer, address - moduleAddress);
        }
        else
        {
            NN_LOG("   0x%P (unknown)\n", address);
        }
    }

    void PrintBacktraceWithModule(uintptr_t addresses[], int addressCount)
    {
        NN_LOG("Stack trace:\n");
        for (int i = 0; i < addressCount; i++)
        {
            PrintBacktraceItemWithModule(addresses[i]);
        }
    }
#endif

    void PrintBacktrace(uintptr_t addresses[], int addressCount)
    {
        PrintBacktraceWithSymbol(addresses, addressCount);
#if defined(NN_BUILD_CONFIG_OS_HORIZON)
        PrintBacktraceWithModule(addresses, addressCount);
#endif
    }
}

// バックトレースの正当性テストは、Clang の Debug ビルドのみ対応
#if defined(NN_BUILD_CONFIG_COMPILER_CLANG) && defined(NN_SDK_BUILD_DEBUG)
namespace {
    std::stack<uintptr_t> g_LrStack;

    void baz()
    {
        // テスト用にリターンアドレスを格納しておく
        g_LrStack.push(reinterpret_cast<uintptr_t>(__builtin_return_address(0)));

        const int addressCountMax = 32;
        uintptr_t addresses[addressCountMax];
        const int addressCount = nn::diag::GetBacktrace(addresses, addressCountMax);
        PrintBacktrace(addresses, addressCount);

        // バックトレースには、SDK の内部実装の関数が含まれるため、
        // インデックスを進める必要がある
        int i;
        for (i = 0; i < addressCount && addresses[i] != g_LrStack.top(); i++);
        ASSERT_LT(i, addressCount);

        while (!g_LrStack.empty())
        {
            // テスト用に格納したリターンアドレスと、
            // バックトレースで取得したアドレスが等しいかテストする
            ASSERT_EQ(addresses[i++], g_LrStack.top());
            g_LrStack.pop();
        }
    }

    void bar()
    {
        g_LrStack.push(reinterpret_cast<uintptr_t>(__builtin_return_address(0)));

        baz();
    }

    void foo()
    {
        // テスト用にリターンアドレスを格納しておく
        g_LrStack.push(reinterpret_cast<uintptr_t>(__builtin_return_address(0)));

        bar();
    }
}

// 取得したバックトレースが正しいことを確認するテスト
TEST(GetBacktraceTest, CheckBacktrace)
{
    foo();
}
#endif // defined(NN_BUILD_CONFIG_COMPILER_CLANG) && defined(NN_SDK_BUILD_DEBUG)

//-----------------------------------------------------------------------------

// バックトレースが取れることを確認するテスト
// バックトレースが正しいかはテストしない
TEST(GetBacktraceTest, GetBacktrace)
{
    const int addressCountMax = 32;
    uintptr_t addresses[addressCountMax];

    int addressCount;

    addressCount = nn::diag::GetBacktrace(addresses, addressCountMax);
    PrintBacktrace(addresses, addressCount);
    EXPECT_GT(addressCount, 0);

    addressCount = nn::diag::GetBacktrace(addresses, 1);
    PrintBacktrace(addresses, 1);
    EXPECT_EQ(addressCount, 1);
}

//-----------------------------------------------------------------------------

namespace
{
    int RecursiveFunction(int recurrenceCount)
    {
        NN_ASSERT_GREATER_EQUAL(recurrenceCount, 1);

        // テストで上限の要素数に達しないように、多目にとる。
        const int addressCountMax = 64;
        static uintptr_t addresses[addressCountMax];

        // ループ展開を阻止するために、末尾再帰にならないようにする。
        volatile const int addressCount = recurrenceCount > 1
            ? RecursiveFunction(recurrenceCount - 1)
            : nn::diag::GetBacktrace(addresses, addressCountMax);

        NN_ASSERT_LESS(addressCount, addressCountMax);

        return addressCount;
    }
}

TEST(GetBacktraceTest, AddressCount)
{
    const int addressCountMax = 32;
    static uintptr_t addresses[addressCountMax];

    const auto baseDepthCount = nn::diag::GetBacktrace(addresses, addressCountMax);
    ASSERT_LE(baseDepthCount, addressCountMax);

    for (auto recurrenceCount = 1; recurrenceCount <= 31; recurrenceCount++)
    {
        const auto depthCount = RecursiveFunction(recurrenceCount);
        EXPECT_EQ(recurrenceCount + baseDepthCount, depthCount);
    }
}

//-----------------------------------------------------------------------------

namespace {

const size_t FiberStackSize = 16 * 1024;

NN_OS_ALIGNAS_FIBER_STACK   uint8_t  g_FiberStack1[ FiberStackSize ];
NN_OS_ALIGNAS_GUARDED_STACK uint8_t  g_FiberStack2[ FiberStackSize ];

nn::os::FiberType  g_Fiber1;
nn::os::FiberType  g_Fiber2;

void InvokePrintBacktrace()
{
    const int addressCountMax = 32;
    uintptr_t addresses[addressCountMax];

    int addressCount = nn::diag::GetBacktrace(addresses, addressCountMax);
    PrintBacktrace(addresses, addressCount);
    EXPECT_GT(addressCount, 0);

    addressCount = nn::diag::GetBacktrace(addresses, 1);
    PrintBacktrace(addresses, 1);
    EXPECT_EQ(addressCount, 1);
}

nn::os::FiberType* FiberFunction1(void*)
{
    NN_LOG("FiberFunction1: resume\n");

    InvokePrintBacktrace();

    NN_LOG("FiberFunction1: suspend\n");

    nn::os::SwitchToFiber( &g_Fiber2 );

    NN_LOG("FiberFunction1: comeback\n");

    InvokePrintBacktrace();

    NN_LOG("FiberFunction1: finish\n");

    return &g_Fiber2;
}

nn::os::FiberType* FiberFunction2(void*)
{
    NN_LOG("FiberFunction2: resume\n");

    InvokePrintBacktrace();

    NN_LOG("FiberFunction2: suspend\n");

    nn::os::SwitchToFiber( &g_Fiber1 );

    NN_LOG("FiberFunction2: comeback\n");

    InvokePrintBacktrace();

    NN_LOG("FiberFunction2: finish\n");

    return nullptr;
}

}   // namespace

TEST(GetBacktraceTest, GetBacktraceOnFiber)
{
    nn::os::InitializeFiber( &g_Fiber1, FiberFunction1, nullptr, g_FiberStack1, FiberStackSize, nn::os::FiberFlag_NoStackGuard );
    nn::os::InitializeFiber( &g_Fiber2, FiberFunction2, nullptr, g_FiberStack2, FiberStackSize, 0 );

    nn::os::SwitchToFiber( &g_Fiber1 );

    nn::os::FinalizeFiber( &g_Fiber1 );
    nn::os::FinalizeFiber( &g_Fiber2 );
}

//-----------------------------------------------------------------------------

#if defined(NN_BUILD_CONFIG_OS_HORIZON)

namespace
{
    nn::os::FiberType  g_FiberForUserExceptionHandlerTest;

    NN_OS_ALIGNAS_THREAD_STACK uint8_t g_UserExceptionHandlerStack[ 16 * 1024 ];
    NN_OS_ALIGNAS_FIBER_STACK  uint8_t g_FiberStackForUserExceptionHandlerTest[ 16 * 1024 ];

    nntosJmpbufWithFpuContext g_JumpBuf;

    int g_BaseDepthCount;
    int g_NullReferenceDepthCount;

    int InvokeNullReference(int recurrenceCount)
    {
        NN_ASSERT_GREATER_EQUAL(recurrenceCount, 1);

        // ループ展開を阻止するために、末尾再帰にならないようにする。
        volatile const int result = recurrenceCount > 1
            ? InvokeNullReference(recurrenceCount - 1)
            : *reinterpret_cast<volatile int*>(NULL); // 意図的にメモリアクセス違反を発生させる。

        return result;
    }

    void HandleUserException(nn::os::UserExceptionInfo* info)
    {
        // 期待されるアドレス数。
        const int expectCountMax = g_NullReferenceDepthCount + g_BaseDepthCount;

        const int addressCountMax = 64;
        uintptr_t addresses[addressCountMax];

        // 実際のアドレス数より少ない要素数で取得したときは、指定した要素数が取得される。
        for (int expectCount = 1; expectCount < expectCountMax; expectCount++)
        {
            EXPECT_EQ(expectCount, nn::diag::GetBacktrace(addresses, expectCount, info->detail.fp, info->detail.sp, info->detail.pc));
        }

        // 実際のアドレス数と同じ要素数で取得したときは、実際のアドレス数が取得される。
        const auto actualCount = nn::diag::GetBacktrace(addresses, expectCountMax, info->detail.fp, info->detail.sp, info->detail.pc);
        EXPECT_EQ(expectCountMax, actualCount);

        // 実際のアドレス数より大きな要素数で取得したときも、実際のアドレス数が取得される。
        EXPECT_EQ(expectCountMax, nn::diag::GetBacktrace(addresses, expectCountMax + 1, info->detail.fp, info->detail.sp, info->detail.pc));

        // メモリアクセス違反を起こす寸前へ戻る。
        nntosLongjmpWithFpuContext( &g_JumpBuf );
    }

    void TestGetBacktraceOnUserExceptionHandler()
    {
        const int addressCountMax = 32;
        static uintptr_t addresses[addressCountMax];

        g_BaseDepthCount = nn::diag::GetBacktrace(addresses, addressCountMax);
        ASSERT_LE(g_BaseDepthCount, addressCountMax);

    #if defined(NN_SDK_BUILD_DEBUG)
        // Debug ビルドのときは、nn::diag::GetBacktrace() のアドレスが含まれるため、その分を引いておく。
        g_BaseDepthCount--;
    #endif

        // 1~32 の深さで例外を起こし、それぞれ適切なアドレス数が取れることをテストする。
        for (g_NullReferenceDepthCount = 1; g_NullReferenceDepthCount <= 32; g_NullReferenceDepthCount++)
        {
            if (nntosSetjmpWithFpuContext( &g_JumpBuf ) == 0)
            {
                // 指定した深さでメモリアクセス違反を発生させる。
                InvokeNullReference(g_NullReferenceDepthCount);

                // メモリアクセス違反が起こらなければテスト失敗。
                ADD_FAILURE();
            }
        }
    }
}

TEST(GetBacktraceTest, GetBacktraceOnUserExceptionHandler)
{
    // VSI デバッガ接続時にもユーザ例外ハンドラを動作させる。
    nn::os::EnableUserExceptionHandlerOnDebugging(true);

    // 例外時用スタックを使う場合。
    nn::os::SetUserExceptionHandler(HandleUserException, &g_UserExceptionHandlerStack, sizeof(g_UserExceptionHandlerStack), nn::os::UserExceptionInfoUsesHandlerStack);
    TestGetBacktraceOnUserExceptionHandler();

    // 例外発生スレッドのスタックを使う場合。
    nn::os::SetUserExceptionHandler(HandleUserException, nn::os::HandlerStackUsesThreadStack, 0, nn::os::UserExceptionInfoUsesThreadStack);
    TestGetBacktraceOnUserExceptionHandler();

    // ファイバーで例外が発生する場合。
    nn::os::InitializeFiber( &g_FiberForUserExceptionHandlerTest, [](void*) -> nn::os::FiberType*
    {
        // 例外時用スタックを使う場合。
        nn::os::SetUserExceptionHandler(HandleUserException, &g_UserExceptionHandlerStack, sizeof(g_UserExceptionHandlerStack), nn::os::UserExceptionInfoUsesHandlerStack);
        TestGetBacktraceOnUserExceptionHandler();

        // 例外発生スレッドのスタックを使う場合。
        nn::os::SetUserExceptionHandler(HandleUserException, nn::os::HandlerStackUsesThreadStack, 0, nn::os::UserExceptionInfoUsesThreadStack);
        TestGetBacktraceOnUserExceptionHandler();

        return nullptr;

    }, nullptr, g_FiberStackForUserExceptionHandlerTest, sizeof(g_FiberStackForUserExceptionHandlerTest), nn::os::FiberFlag_NoStackGuard);

    nn::os::SwitchToFiber(&g_FiberForUserExceptionHandlerTest);
}

#endif // #if defined(NN_BUILD_CONFIG_OS_HORIZON)

//-----------------------------------------------------------------------------
