﻿/*--------------------------------------------------------------------------------*
  Copyright (C)Nintendo All rights reserved.

  These coded instructions, statements, and computer programs contain proprietary
  information of Nintendo and/or its licensed developers and are protected by
  national and international copyright laws. They may not be disclosed to third
  parties or copied or duplicated in any form, in whole or in part, without the
  prior written consent of Nintendo.

  The content herein is highly confidential and should be handled accordingly.
 *--------------------------------------------------------------------------------*/

#include <stack>
#include <memory>
#include <nnt/nntest.h>
#include <nn/nn_Log.h>
#include <nn/crypto/crypto_Md5Generator.h>
#include <nn/diag/diag_Backtrace.h>
#include <nn/diag/diag_Module.h>
#include <nn/diag/diag_Symbol.h>
#include <nn/util/util_ScopeExit.h>
#include <nn/util/util_StringUtil.h>

// Forward declaration of an internal SDK function (namespace nn::diag::detail),
// presumably because it is not exposed by the public headers included above.
// Used by PrintBacktraceItemWithModulePath() below, which treats a return value
// of 0 as "unknown" and otherwise subtracts the returned value from `address` to
// print an offset -- i.e. it appears to return the base address of the module
// containing `address` and write that module's path into outPathBuffer.
// NOTE(review): confirm the exact contract against the SDK implementation.
namespace nn { namespace diag { namespace detail {
    uintptr_t GetModulePathImpl(char* outPathBuffer, size_t pathBufferSize, uintptr_t address) NN_NOEXCEPT;
}}}

namespace {
    // Logs one backtrace entry as "  0x<address> <module path>+0x<offset from module base>",
    // or "  0x<address> (unknown)" when GetModulePathImpl() returns 0 for the address.
    void PrintBacktraceItemWithModulePath(uintptr_t address)
    {
        char modulePath[128];
        const auto moduleAddress = nn::diag::detail::GetModulePathImpl(modulePath, sizeof(modulePath), address);
        if (moduleAddress != 0u)
        {
            // The offset is a uintptr_t (64 bits on 64-bit targets) while %X consumes an
            // unsigned int; widen explicitly and use the matching %llX so the vararg type
            // agrees with the conversion specifier.
            NN_LOG("  0x%P %s+0x%llX\n", address, modulePath, static_cast<unsigned long long>(address - moduleAddress));
        }
        else
        {
            NN_LOG("  0x%P (unknown)\n", address);
        }
    }

    // Logs one backtrace entry as "  0x<address> <symbol>+0x<offset from symbol start>",
    // appending " (too far)" when the address lies beyond the end of the resolved symbol,
    // or "  0x<address> (unknown)" when GetSymbolName() finds no symbol.
    void PrintBacktraceItemWithSymbolName(uintptr_t address)
    {
        // For a noreturn callee, the bl instruction is the last instruction of the calling
        // function and the address right after it is the start of the next function.
        // Look up address - 1 so that the next function's symbol is not picked up instead.
        char symbolName[128];
        const auto symbolAddress = nn::diag::GetSymbolName(symbolName, sizeof(symbolName), address - 1);
        if (symbolAddress != 0u)
        {
            // For the same noreturn case, treat everything up to and including the next
            // function's first address (symbolAddress + symbolSize) as inside this function.
            const auto symbolSize = nn::diag::GetSymbolSize(symbolAddress);
            const auto isOutOfFunction = address > symbolAddress + symbolSize;
            NN_UNUSED(isOutOfFunction); // NOTE(review): presumably for builds where NN_LOG expands to nothing.
            // Same uintptr_t-vs-%X mismatch as above: widen and use %llX.
            NN_LOG("  0x%P %s+0x%llX%s\n", address, symbolName, static_cast<unsigned long long>(address - symbolAddress), isOutOfFunction ? " (too far)" : "");
        }
        else
        {
            NN_LOG("  0x%P (unknown)\n", address);
        }
    }

    // Logs traces[0 .. traceCount) twice: first resolved to symbol names,
    // then resolved to module paths. The array is only read.
    void PrintBacktrace(const uintptr_t traces[], int traceCount)
    {
        NN_LOG("User Backtrace with Symbol\n");
        for (int i = 0; i < traceCount; i++)
        {
            PrintBacktraceItemWithSymbolName(traces[i]);
        }
        NN_LOG("\n");

        NN_LOG("User Backtrace with Module\n");
        for (int i = 0; i < traceCount; i++)
        {
            PrintBacktraceItemWithModulePath(traces[i]);
        }
        NN_LOG("\n");
    }
}

// Smoke test: captures the current backtrace and logs every frame, resolved both
// to a symbol name and to a module path. It only asserts that at least one frame
// was captured; whether the logged module paths are correct is not verified.
TEST(GetModuleTest, GetModule)
{
    constexpr int MaxFrameCount = 32;
    uintptr_t frames[MaxFrameCount];
    const int frameCount = nn::diag::GetBacktrace(frames, MaxFrameCount);

    // The printed output is for manual inspection only.
    PrintBacktrace(frames, frameCount);

    EXPECT_GT(frameCount, 0);
}

// Enumerates the loaded modules with nn::diag::GetAllModuleInfo() and logs them.
// On Horizon it additionally checks:
//  - the number of modules;
//  - that GetRequiredBufferSizeForGetAllModuleInfo() equals the ModuleInfo array
//    plus the NUL-terminated path strings;
//  - that every ModuleInfo record and every path string lies inside the
//    caller-provided buffer.
TEST(GetModuleTest, GetAllModuleInfo)
{
    const auto bufferSize = nn::diag::GetRequiredBufferSizeForGetAllModuleInfo();

    // Owned by unique_ptr (same as the GetReadOnlyDataSectionRange test) so it is
    // released on every exit path, including early returns from ASSERT_*.
    std::unique_ptr<nn::Bit8[]> buffer(new nn::Bit8[bufferSize]);

    nn::diag::ModuleInfo* modules;
    const auto moduleCount = nn::diag::GetAllModuleInfo(&modules, buffer.get(), bufferSize);

    // A "*" field width consumes an int argument; sizeof yields size_t, so convert explicitly.
    const int addressColumnWidth = static_cast<int>(sizeof(uintptr_t) * 2);

    NN_LOG("Modules:\n");
    NN_LOG("  %-*s   %-*s   path\n", addressColumnWidth, "base", addressColumnWidth, "size");
    for (auto i = 0; i < moduleCount; i++)
    {
        const auto& module = modules[i];
        NN_LOG("  0x%P 0x%P %s\n", module.baseAddress, module.size, module.path);
    }
    NN_LOG("\n");

#if defined(NN_BUILD_CONFIG_OS_HORIZON) // The Windows build only logs the module list for now.

    const int expectedModuleCount = 3; // nnrtld, testDiag_Module and nnSdk(Jp/En).
    // ASSERT rather than EXPECT: the checks below index modules[0..2] unconditionally
    // and must not run past the array when fewer modules were returned.
    ASSERT_EQ(expectedModuleCount, moduleCount);

    // size_t (not int) so the comparison with bufferSize is not a signed/unsigned mix.
    const size_t expectedBufferSize =
        sizeof(nn::diag::ModuleInfo) * expectedModuleCount
            + sizeof("nnrtld")
            + sizeof("testDiag_Module")
#if defined(NN_SDK_BUILD_DEBUG) || defined(NN_SDK_BUILD_DEVELOP)
            + sizeof("nnSdkXX"); // XX is Jp or En.
#elif defined(NN_SDK_BUILD_RELEASE)
            + sizeof("nnSdk");
#else
    #error ”未サポートのビルドタイプです。”
#endif
    EXPECT_EQ(expectedBufferSize, bufferSize);

    // Half-open range [bufferBegin, bufferEnd) occupied by the caller-provided buffer.
    // Objects are checked as [begin, begin + size) ⊆ [bufferBegin, bufferEnd), so an
    // object may end exactly at bufferEnd (the last path string normally does).
    // Path strings are checked via the address module.path designates (reinterpret_cast
    // applies array-to-pointer decay), not via &module.path, which for a pointer member
    // would only be the address of the member inside the ModuleInfo record.
    const uintptr_t bufferBegin = reinterpret_cast<uintptr_t>(buffer.get());
    const uintptr_t bufferEnd = bufferBegin + bufferSize;

    // nnrtld
    {
        const auto& module = modules[0];
        const uintptr_t moduleBegin = reinterpret_cast<uintptr_t>(&module);
        EXPECT_GE(moduleBegin, bufferBegin);
        EXPECT_LE(moduleBegin + sizeof(module), bufferEnd);

        EXPECT_STREQ(module.path, "nnrtld");
        const uintptr_t pathBegin = reinterpret_cast<uintptr_t>(module.path);
        EXPECT_GE(pathBegin, bufferBegin);
        EXPECT_LE(pathBegin + sizeof("nnrtld"), bufferEnd);
    }

    // testDiag_Module
    {
        const auto& module = modules[1];
        const uintptr_t moduleBegin = reinterpret_cast<uintptr_t>(&module);
        EXPECT_GE(moduleBegin, bufferBegin);
        EXPECT_LE(moduleBegin + sizeof(module), bufferEnd);

        EXPECT_STREQ(module.path, "testDiag_Module");
        const uintptr_t pathBegin = reinterpret_cast<uintptr_t>(module.path);
        EXPECT_GE(pathBegin, bufferBegin);
        EXPECT_LE(pathBegin + sizeof("testDiag_Module"), bufferEnd);
    }

    // nnSdk (nnSdkJp / nnSdkEn in Debug and Develop builds)
    {
        const auto& module = modules[2];
        const uintptr_t moduleBegin = reinterpret_cast<uintptr_t>(&module);
        EXPECT_GE(moduleBegin, bufferBegin);
        EXPECT_LE(moduleBegin + sizeof(module), bufferEnd);

        const uintptr_t pathBegin = reinterpret_cast<uintptr_t>(module.path);
        EXPECT_GE(pathBegin, bufferBegin);
#if defined(NN_SDK_BUILD_DEBUG) || defined(NN_SDK_BUILD_DEVELOP)
        EXPECT_TRUE(nn::util::Strncmp(module.path, "nnSdkJp", sizeof("nnSdkJp") - 1) == 0 || nn::util::Strncmp(module.path, "nnSdkEn", sizeof("nnSdkEn") - 1) == 0);
        EXPECT_LE(pathBegin + sizeof("nnSdkXX"), bufferEnd);
#elif defined(NN_SDK_BUILD_RELEASE)
        EXPECT_STREQ(module.path, "nnSdk");
        EXPECT_LE(pathBegin + sizeof("nnSdk"), bufferEnd);
#else
    #error ”未サポートのビルドタイプです。”
#endif
    }
#endif // #if NN_BUILD_CONFIG_OS_HORIZON
}

// For every loaded module, checks nn::diag::GetReadOnlyDataSectionRange():
//  - for modules whose path starts with "nnrtld" or "nnSdk" it must return false;
//  - for every other module it must return true with a non-inverted range, and the
//    whole range must be readable (exercised by feeding it to an MD5 generator).
TEST(GetModuleTest, GetReadOnlyDataSectionRange)
{
    const auto bufferSize = nn::diag::GetRequiredBufferSizeForGetAllModuleInfo();

    std::unique_ptr<nn::Bit8[]> buffer(new nn::Bit8[bufferSize]);

    nn::diag::ModuleInfo* modules;
    const auto moduleCount = nn::diag::GetAllModuleInfo(&modules, buffer.get(), bufferSize);

    for (auto i = 0; i < moduleCount; i++)
    {
        const auto& module = modules[i];

        uintptr_t startAddress = 0;
        uintptr_t endAddress = 0;

        // Prefix match, so "nnSdk" also covers "nnSdkJp" / "nnSdkEn".
        const bool isExcludedModule =
            nn::util::Strncmp(module.path, "nnrtld", sizeof("nnrtld") - 1) == 0 ||
            nn::util::Strncmp(module.path, "nnSdk", sizeof("nnSdk") - 1) == 0;
        if (isExcludedModule)
        {
            ASSERT_FALSE(nn::diag::GetReadOnlyDataSectionRange(&startAddress, &endAddress, module.baseAddress));
            continue;
        }

        ASSERT_TRUE(nn::diag::GetReadOnlyDataSectionRange(&startAddress, &endAddress, module.baseAddress));
        // An inverted range would make endAddress - startAddress wrap to a huge size and
        // the read below would run far past the section; fail the test instead of crashing.
        ASSERT_LE(startAddress, endAddress);

        // Read the entire rodata range to confirm it is accessible.
        nn::crypto::Md5Generator generator;
        generator.Initialize();
        generator.Update(reinterpret_cast<void*>(startAddress), endAddress - startAddress);
    }
}
