﻿/*--------------------------------------------------------------------------------*
  Copyright (C)Nintendo All rights reserved.

  These coded instructions, statements, and computer programs contain proprietary
  information of Nintendo and/or its licensed developers and are protected by
  national and international copyright laws. They may not be disclosed to third
  parties or copied or duplicated in any form, in whole or in part, without the
  prior written consent of Nintendo.

  The content herein is highly confidential and should be handled accordingly.
 *--------------------------------------------------------------------------------*/

#include <algorithm>
#include <new>

#include <nn/nn_Common.h>
#include <nn/os/os_Thread.h>
#include <nn/os/os_Types.h>
#include <nn/os/os_SdkThread.h>
#include <nn/os/os_SdkThreadInfo.h>
#include <nn/util/util_StringUtil.h>

#include "detail/os_ThreadManager-os.horizon.h"

#include "profiler_CodeRewriting.h"
#include "profiler_CommMessages.h"
#include "profiler_Comms.h"
#include "profiler_Defines.h"
#include "profiler_LibModule.h"
#include "profiler_Logging.h"
#include "profiler_Memory.h"
#include "profiler_Messages.h"
#include "profiler_SamplingThread.h"
#include "profiler_Svc.autogen.h"
#include "profiler_TargetApplication.h"
#include "profiler_ThreadNode.h"
#include "profiler_Workarea.h"


namespace nn { namespace profiler {

// Default core mask: cores 0, 1 and 2. Core 3 is not included.
uint32_t TargetApplication::s_CoreMask = (1u << 0) | (1u << 1) | (1u << 2);

// Number of set bits recorded by the last SetCoreMask() call.
// NOTE(review): starts at 0 even though the default mask above has three bits set;
// it only agrees with s_CoreMask once SetCoreMask() has been called.
int TargetApplication::s_CoreCount = 0;

// Replaced in Initialize() by the value derived from the process priority mask.
int TargetApplication::s_MaximumThreadPriority = nn::os::HighestThreadPriority;

// The instance created by Initialize() and destroyed by Finalize(); null otherwise.
TargetApplication* TargetApplication::s_CurrentApplication = nullptr;


namespace /*anonymous*/
{
    // Cleared by Initialize() and set by Finalize(). No code in this file reads it.
    // NOTE(review): volatile provides no cross-thread synchronization; if another thread
    // is ever meant to poll this flag, std::atomic<bool> is the appropriate type.
    volatile bool sShouldExit{ false };
} // anonymous



/*****************************************************************************
    Static Class Definitions
 *****************************************************************************/
// Creates the shared TargetApplication instance (a no-op if one already exists) and
// computes s_MaximumThreadPriority from the current process's kernel priority mask.
void TargetApplication::Initialize()
{
    if (s_CurrentApplication != nullptr) { return; }

    // Raw storage from the profiler allocator is zero-filled before the constructor runs,
    // so members the constructor does not set start out as zero. The pointer is published
    // before construction in case code reached from the constructor reads it.
    s_CurrentApplication = Memory::GetInstance()->Allocate<TargetApplication>();
    if (s_CurrentApplication != nullptr)
    {
        memset(s_CurrentApplication, 0, sizeof(TargetApplication));
        new (s_CurrentApplication) TargetApplication();
    }
    // On allocation failure s_CurrentApplication stays null. GetCurrent() then reports no
    // application and a later Initialize() call can retry, instead of memset writing
    // through a null pointer.

    // Ask the kernel which priorities this process may use.
    nn::Bit64 prioMask;
    auto result = nn::svc::profiler::GetInfo(
        &prioMask,
        nn::svc::InfoType_PriorityMask,
        nn::svc::PSEUDO_HANDLE_CURRENT_PROCESS,
        0);

    // Used when the mask cannot be read or has no bits set.
    int foundPriority = nn::os::LowestThreadPriority;
    if (result.IsSuccess())
    {
        // Take the lowest set bit and convert it from a kernel priority to an nn::os
        // priority by subtracting UserThreadPriorityOffset.
        for (int i = 0; i < 64; ++i)
        {
            nn::Bit64 val = (static_cast<uint64_t>(1) << i);
            if ((prioMask & val) != 0)
            {
                foundPriority = (i - nn::os::detail::UserThreadPriorityOffset);
                break;
            }
        }
    }
    //foundPriority = std::max(foundPriority, nn::os::HighestSystemThreadPriority);
    // Clamp so the result is never numerically greater than HighestThreadPriority.
    // NOTE(review): the LowestThreadPriority fallback above is clamped by this as well;
    // confirm that is the intended result when GetInfo fails.
    foundPriority = std::min(foundPriority, nn::os::HighestThreadPriority);
    s_MaximumThreadPriority = foundPriority;

    sShouldExit = false;
}



void TargetApplication::Finalize()
{
    sShouldExit = true;

    if (s_CurrentApplication != nullptr)
    {
        s_CurrentApplication->~TargetApplication();
        Memory::GetInstance()->Free(s_CurrentApplication);
        s_CurrentApplication = nullptr;
    }
}



// Returns the instance created by Initialize(), or nullptr when there is none.
TargetApplication* TargetApplication::GetCurrent()
{
    TargetApplication* const current = s_CurrentApplication;
    return current;
}



// Returns the current core mask (the default from its definition until SetCoreMask()).
uint32_t TargetApplication::GetCoreMask()
{
    const uint32_t mask = s_CoreMask;
    return mask;
}



// Stores a new core mask and records how many cores (set bits) it contains.
void TargetApplication::SetCoreMask(uint32_t mask)
{
    s_CoreMask = mask;
    s_CoreCount = nn::util::popcount(mask);
}



// Returns the bit count recorded by the last SetCoreMask() call, or 0 if it has not
// been called yet.
int TargetApplication::GetCoreCount()
{
    const int count = s_CoreCount;
    return count;
}



// Returns the current priority of thread, or nn::os::DefaultThreadPriority when thread
// is null or not in the Initialized, Started or Exited state.
int TargetApplication::GetThreadPriority(nn::os::ThreadType* thread)
{
    if (thread == nullptr)
    {
        return nn::os::DefaultThreadPriority;
    }

    // Only threads in these states are queried. This is the same guard this file puts in
    // front of nn::os::GetThreadNamePointer, which aborts otherwise; presumably
    // GetThreadCurrentPriority has the same precondition — TODO confirm.
    const bool queryable =
        (thread->_state == nn::os::ThreadType::State_Initialized) ||
        (thread->_state == nn::os::ThreadType::State_Started) ||
        (thread->_state == nn::os::ThreadType::State_Exited);

    return queryable ? nn::os::GetThreadCurrentPriority(thread)
                     : nn::os::DefaultThreadPriority;
}



// Reports thread's ideal core and affinity mask through idealCore and coreMask.
// For a null thread, or one not in the Initialized, Started or Exited state, the
// results are nn::os::IdealCoreDontCare and 0.
void TargetApplication::GetThreadCoreMask(nn::os::ThreadType* thread, int* idealCore, int* coreMask)
{
    *idealCore = nn::os::IdealCoreDontCare;
    nn::Bit64 affinity = 0;

    // Short-circuit keeps the state reads away from a null thread.
    const bool canQuery = (thread != nullptr) &&
        ((thread->_state == nn::os::ThreadType::State_Initialized) ||
         (thread->_state == nn::os::ThreadType::State_Started) ||
         (thread->_state == nn::os::ThreadType::State_Exited));
    if (canQuery)
    {
        nn::os::GetThreadCoreMask(idealCore, &affinity, thread);
    }

    // The mask is returned through an int, so it must fit in 32 bits.
    NN_SDK_ASSERT(affinity <= UINT32_MAX);
    *coreMask = static_cast<int>(affinity);
}



// Forwards to nn::os::GetThreadId; thread is passed through without a null check.
nn::os::ThreadId TargetApplication::GetThreadId(nn::os::ThreadType* thread)
{
    const nn::os::ThreadId id = nn::os::GetThreadId(thread);
    return id;
}



// Looks up the ThreadType recorded for the given thread id in the current instance's
// thread list. Returns nullptr when the id is not registered or when no
// TargetApplication instance exists.
nn::os::ThreadType* TargetApplication::GetThreadType(nn::os::ThreadId thread)
{
    // todo: Determine how to go from a true thread id to a thread type

    // Before Initialize() or after Finalize() there is no list to search; report
    // "not found" instead of dereferencing a null instance.
    if (s_CurrentApplication == nullptr) { return nullptr; }

    auto item = s_CurrentApplication->m_threadList.GetByKey(thread);
    return (item != nullptr) ? item->threadType : nullptr;
}



// Returns nn::os's name pointer for thread, or nullptr when thread is null or not in
// the Initialized, Started or Exited state.
const char* TargetApplication::GetThreadName(nn::os::ThreadType* thread)
{
    if (thread == nullptr)
    {
        return nullptr;
    }

    // nn::os::GetThreadNamePointer aborts unless the thread is runnable, so apply the
    // same state check here and report "no name" rather than crash.
    const bool runnable =
        (thread->_state == nn::os::ThreadType::State_Initialized) ||
        (thread->_state == nn::os::ThreadType::State_Started) ||
        (thread->_state == nn::os::ThreadType::State_Exited);

    return runnable ? nn::os::GetThreadNamePointer(thread) : nullptr;
}



// Registers thread with the current TargetApplication instance and returns its thread
// list entry (see RegisterThread). Returns nullptr when no instance exists.
ThreadListItem* TargetApplication::GetThreadData(nn::os::ThreadType* thread)
{
    // Before Initialize() or after Finalize() there is no instance to register with;
    // avoid dereferencing a null pointer.
    if (s_CurrentApplication == nullptr) { return nullptr; }

    // RegisterThread() performs exactly the insert-then-fill sequence this function
    // needs, so reuse it instead of duplicating the logic.
    return s_CurrentApplication->RegisterThread(thread);
}



// Returns the priority computed by Initialize() (nn::os::HighestThreadPriority if
// Initialize() has not run).
int TargetApplication::GetMaximumThreadPriority()
{
    const int priority = s_MaximumThreadPriority;
    return priority;
}


// Tries to find the name of the module whose memory spans [baseaddr, endaddr).
//
// Two layouts are tried, in order:
//   1. At baseaddr: an 8-byte length field, read with its two 32-bit halves swapped,
//      followed by the name bytes at baseaddr + 8.
//   2. Just before the GNU build-id note near the end of the region: the name bytes
//      followed by zero padding, preceded by an 8-byte length field in the same
//      half-swapped format. The stored length must equal the scanned length.
//
// On success, returns true; pOutName points into the scanned memory (nothing is copied)
// and pOutNameLength holds the stored length. NOTE(review): nothing here checks for a
// NUL terminator, so callers should rely on pOutNameLength.
// On failure, returns false with pOutName == nullptr and pOutNameLength == 0.
//
// NOTE(review): addresses in [baseaddr, endaddr) are dereferenced directly; the caller
// must ensure the range is mapped and readable.
bool TargetApplication::FindModuleName(
    uintptr_t baseaddr,
    uintptr_t endaddr,
    const char*& pOutName,
    size_t& pOutNameLength)
{
    NN_SDK_ASSERT(baseaddr <= endaddr);

    // Bounds how far below endaddr the build-id search may go.
    const int RWRegionAlignment = 0x1000;

    pOutName = nullptr;
    pOutNameLength = 0;

    // Look for the module name at the start of the RO section
    {
        // Swap the two 32-bit halves of the first 8 bytes to recover the stored length.
        uint64_t tempValue = *reinterpret_cast<uint64_t*>(baseaddr);
        uint64_t value = (tempValue << 32) | (tempValue >> 32);
        // NOTE(review): a value of 0 also passes this check. That returns true with an
        // empty name and skips the build-id fallback below; confirm this is intended.
        if (value <= MaximumFilePathLength)
        {
            pOutNameLength = static_cast<size_t>(value);
            pOutName = reinterpret_cast<const char*>(baseaddr + 8);
            return true;
        }
    }

    // Look for the module name before the .note.gnu.build-id section
    {
        const uint32_t buildIdHeader = GnuBuildId::GnuId;
        // Words to step back from the matched GnuId word. FindModuleBuildId() treats the
        // note as starting 3 words before GnuId, so 4 lands on the word just before it.
        const size_t buildIdHeaderU32Count = 4;

        // Scan backwards one 32-bit word at a time, starting with the last word before
        // endaddr and going no lower than RWRegionAlignment + 24 bytes below endaddr,
        // clamped to baseaddr.
        uint32_t* searchAddr = reinterpret_cast<uint32_t*>(endaddr - 4);
        uint32_t* endSearchAddr = reinterpret_cast<uint32_t*>(endaddr - RWRegionAlignment - 24);
        if (endSearchAddr < reinterpret_cast<uint32_t*>(baseaddr))
        {
            endSearchAddr = reinterpret_cast<uint32_t*>(baseaddr);
        }

        while (searchAddr >= endSearchAddr && *searchAddr != buildIdHeader)
        {
            --searchAddr;
        }
        // If GnuId was not found, searchAddr is already below endSearchAddr, so the
        // check after this step fails as well.
        searchAddr -= buildIdHeaderU32Count;

        if (searchAddr < endSearchAddr) { return false; }

        // Start from the last byte of the word before the note. Search backwards no lower
        // than MaximumFilePathLength + 8 bytes below that word, clamped to baseaddr.
        uint8_t* moduleIdFinder = reinterpret_cast<uint8_t*>(searchAddr) + 3;
        uint8_t* endIdSearchAddr = reinterpret_cast<uint8_t*>(searchAddr) - MaximumFilePathLength - 8;
        if (endIdSearchAddr < reinterpret_cast<uint8_t*>(baseaddr))
        {
            endIdSearchAddr = reinterpret_cast<uint8_t*>(baseaddr);
        }

        // Skip the zero bytes between the name and the note.
        while (moduleIdFinder >= endIdSearchAddr && *moduleIdFinder == 0)
        {
            --moduleIdFinder;
        }

        if (moduleIdFinder < endIdSearchAddr) { return false; }

        // Last byte of the name.
        uint8_t* endModuleName = moduleIdFinder;

        // Walk back over the non-zero name bytes to the zero byte just before them.
        while (moduleIdFinder >= endIdSearchAddr && *moduleIdFinder != 0)
        {
            --moduleIdFinder;
        }

        // That zero byte is treated as the last byte of the 8-byte length field, which
        // must lie entirely inside the search window.
        if (moduleIdFinder - 7 < endIdSearchAddr) { return false; }

        uint8_t* startModuleName = moduleIdFinder + 1;
        // NOTE(review): moduleIdFinder - 7 has no guaranteed 8-byte alignment, so the
        // uint64_t load below may be misaligned. std::memcpy into a local would avoid
        // relying on the platform tolerating that.
        uint64_t* moduleIdReader = reinterpret_cast<uint64_t*>(moduleIdFinder - 7);
        uint64_t calculatedNameLength = static_cast<uint64_t>(endModuleName - startModuleName + 1);

        // Same half-swapped encoding as the first layout. Accept the name only when the
        // stored length agrees with the number of name bytes actually scanned.
        uint64_t tempValue = *moduleIdReader;
        uint64_t storedNameLength = (tempValue << 32) | (tempValue >> 32);
        if (storedNameLength <= MaximumFilePathLength && storedNameLength == calculatedNameLength)
        {
            pOutNameLength = static_cast<size_t>(storedNameLength);
            pOutName = reinterpret_cast<const char*>(moduleIdReader + 1);
            return true;
        }
    }

    return false;
}


bool TargetApplication::FindModuleBuildId(uintptr_t baseaddr, uintptr_t endaddr, void* gnuBuildId)
{
    GnuBuildId* buildId = static_cast<GnuBuildId*>(gnuBuildId);

    NN_SDK_ASSERT(baseaddr <= endaddr);

    const int RWRegionAlignment = 0x1000;

    const uint32_t buildIdHeader = GnuBuildId::GnuId;

    uint32_t* searchAddr = reinterpret_cast<uint32_t*>(endaddr - 4);
    uint32_t* endSearchAddr = reinterpret_cast<uint32_t*>(endaddr - RWRegionAlignment - 24);
    if (endSearchAddr < reinterpret_cast<uint32_t*>(baseaddr))
    {
        endSearchAddr = reinterpret_cast<uint32_t*>(baseaddr);
    }

    while (searchAddr >= endSearchAddr && *searchAddr != buildIdHeader)
    {
        --searchAddr;
    }
    searchAddr -= 3;

    if (searchAddr < endSearchAddr) { return false; }

    buildId->Fill(searchAddr);

    return true;
}



/*****************************************************************************
    Class Definitions
 *****************************************************************************/
// Value-initializes the thread list, then scans for code regions. Initialize() runs
// this constructor with placement new on zero-filled storage.
TargetApplication::TargetApplication()
    : m_threadList()
{
    this->FindCodeRegions();
}



// Nothing to release explicitly; members are destroyed implicitly. Finalize() calls
// this directly before returning the storage to Memory.
TargetApplication::~TargetApplication() = default;



// Reports the application's name through pName and pLength.
//
// Preference order:
//   1. argv[0] from nn::os::GetHostArgv(), when argc > 0 and argv[0] is non-empty.
//   2. Among the first m_staticCodeRegionCount entries of gPastModules, the first whose
//      name is longer than 4 characters and ends in ".nss" (compared with
//      nn::util::Strnicmp).
// If neither yields a name, pName is set to nullptr and pLength to 0.
void TargetApplication::GetApplicationName(const char*& pName, size_t& pLength) const
{
    if (nn::os::GetHostArgc() > 0)
    {
        pName = nn::os::GetHostArgv()[0];
        // Treat a null argv[0] like an empty one, so Strnlen never receives a null
        // pointer. Both cases fall through to the module search below.
        int length = (pName != nullptr) ? nn::util::Strnlen(pName, MaximumFilePathLength) : 0;
        if (length < 0)
        {
            pName = nullptr;
            pLength = 0;
            return;
        }
        else if (length > 0)
        {
            pLength = static_cast<size_t>(length);
            return;
        }
        // else (length == 0)
        //   Try looking for the module name in the code regions.
        //   For whatever reason, system processes have an argc of 1, but no real argv contents.
        //   Maybe so that command line arguments can be passed in without breaking paradigm that
        //   argv[0] is the application name?
    }

    if (m_staticCodeRegionCount > 0)
    {
        int count = m_staticCodeRegionCount;
        for (auto& x : gPastModules)
        {
            if (count <= 0) { break; }
            --count;

            // Require at least one character before the 4-character extension.
            int nameLength = nn::util::Strnlen(x.Name, MaximumFilePathLength);
            if (nameLength > 4)
            {
                int compareValue = nn::util::Strnicmp(
                    &x.Name[nameLength - 4],
                    ".nss",
                    4);
                if (compareValue == 0)
                {
                    pName = x.Name;
                    pLength = static_cast<size_t>(nameLength);
                    return;
                }
            }
        }
    }

    pName = nullptr;
    pLength = 0;
    return;
}



// Rebuilds the module lists and recomputes the code address ranges. The call order
// matters: FindDynamicCodeRegions() starts from the static range computed by
// FindStaticCodeRegions().
void TargetApplication::FindCodeRegions()
{
    // Start from empty module lists.
    ClearModuleLists();

    // Statically loaded modules and the address range they cover.
    BuildStaticModules();
    FindStaticCodeRegions();

    // Active modules, folded into the overall range.
    BuildActiveModules();
    FindDynamicCodeRegions();
}



// Computes the [min, max) address range covered by gPastModules and records how many
// modules it contains. With no modules, min stays UINT64_MAX and max stays 0.
void TargetApplication::FindStaticCodeRegions()
{
    m_staticCodeRegionCount = 0;
    m_minStaticCodeAddress = UINT64_MAX;
    m_maxStaticCodeAddress = 0;

    for (const auto& module : gPastModules)
    {
        const auto moduleEnd = module.Address + module.Size;
        if (module.Address < m_minStaticCodeAddress)
        {
            m_minStaticCodeAddress = module.Address;
        }
        if (m_maxStaticCodeAddress < moduleEnd)
        {
            m_maxStaticCodeAddress = moduleEnd;
        }
    }

    // Only static regions are known at this point; dynamic ones are added later.
    m_staticCodeRegionCount = gPastModules.size();
    m_codeRegionCount = m_staticCodeRegionCount;
}



// Widens the static code range to include every module in gActiveModules. The total
// region count is the static count plus the number of active modules.
void TargetApplication::FindDynamicCodeRegions()
{
    m_minCodeAddress = m_minStaticCodeAddress;
    m_maxCodeAddress = m_maxStaticCodeAddress;

    for (const auto& module : gActiveModules)
    {
        const auto moduleEnd = module.Address + module.Size;
        if (module.Address < m_minCodeAddress)
        {
            m_minCodeAddress = module.Address;
        }
        if (m_maxCodeAddress < moduleEnd)
        {
            m_maxCodeAddress = moduleEnd;
        }
    }

    m_codeRegionCount = m_staticCodeRegionCount + gActiveModules.size();
}



// Returns the lowest code address computed by FindDynamicCodeRegions().
uintptr_t TargetApplication::GetMinCodeAddress() const
{
    // The stored value's type must be at least as wide as a pointer.
    NN_STATIC_ASSERT(sizeof(uintptr_t) <= sizeof(m_minCodeAddress));
    const uintptr_t lowest = static_cast<uintptr_t>(m_minCodeAddress);
    return lowest;
}



// Returns the end of the code range (module Address + Size maximum) computed by
// FindDynamicCodeRegions().
uintptr_t TargetApplication::GetMaxCodeAddress() const
{
    // The stored value's type must be at least as wide as a pointer.
    NN_STATIC_ASSERT(sizeof(uintptr_t) <= sizeof(m_maxCodeAddress));
    const uintptr_t highest = static_cast<uintptr_t>(m_maxCodeAddress);
    return highest;
}



// Returns the number of static plus active code regions last computed.
int TargetApplication::GetCodeRegionCount() const
{
    const int count = m_codeRegionCount;
    return count;
}



// Computes the address just past the end of thread's stack (stack address + size) and
// stores it in *stackBase when stackBase is non-null. Presumably this is the initial
// stack pointer of a downward-growing stack; confirm.
// Returns true only when nn::os reported a non-zero address and size. *stackBase is
// still written when the result is false, as long as thread is non-null.
bool TargetApplication::GetStackStartFromThreadId(
    nn::os::ThreadType* thread,
    uintptr_t *stackBase) const
{
    if (thread == nullptr)
    {
        return false;
    }

    uintptr_t stackAddress;
    size_t stackSize;
    nn::os::GetThreadStackInfo(&stackAddress, &stackSize, thread);

    if (stackBase != nullptr)
    {
        *stackBase = stackAddress + stackSize;
    }

    const bool haveStack = (stackAddress != 0) && (stackSize != 0);
    return haveStack;
}



void TargetApplication::ClearThreadList()
{
    m_threadList.Clear();
}


// Inserts thread into m_threadList, keyed by its thread id, and returns the entry.
// An entry whose threadType is still null is filled from thread; an entry that
// already has a threadType is returned unchanged.
ThreadListItem* TargetApplication::RegisterThread(nn::os::ThreadType* thread)
{
    auto item = m_threadList.Insert(GetThreadId(thread));
    // NOTE(review): guards against Insert() failing (for example, a full fixed-size
    // list) by returning nullptr; confirm against ThreadList. In that case the caller
    // receives nullptr instead of this function dereferencing it.
    if (item != nullptr && item->threadType == nullptr)
    {
        item->Fill(thread);
    }
    return item;
}


void TargetApplication::RegisterAllThreads()
{
    nn::os::ThreadType* startThread = nn::os::GetCurrentThread();
    nn::os::ThreadType *next = startThread;
    size_t loopCount = 0;
    do
    {
        if ((next->_state == nn::os::ThreadType::State_Initialized) ||
            (next->_state == nn::os::ThreadType::State_Started) ||
            (next->_state == nn::os::ThreadType::State_Exited))
        {
            RegisterThread(next);
        }
        auto n = reinterpret_cast<ThreadNode*>(&next->_allThreadsListNode)->next;
        if (next == n) { break; }
        next = n;
        ++loopCount;
    } while (next != startThread && next != nullptr && loopCount < ThreadIdListSize);
}


// Logs thread's address, current priority and name via INFO_LOG. Threads that are
// null, or not in the Initialized, Started or Exited state, are skipped silently.
void PrintThreadInfo(nn::os::ThreadType* thread)
{
    // Tolerate a null thread instead of dereferencing it.
    if (thread == nullptr) { return; }

    // The call to nn::os::GetThreadNamePointer will ABORT unless the thread is runnable.
    // Since we don't want to crash, perform the same check out here before calling.
    if ((thread->_state == nn::os::ThreadType::State_Initialized) ||
        (thread->_state == nn::os::ThreadType::State_Started) ||
        (thread->_state == nn::os::ThreadType::State_Exited))
    {
        const char* threadName = nn::os::GetThreadNamePointer(thread);
        int priority = nn::os::GetThreadCurrentPriority(thread);
        // %p expects a void*, so convert explicitly before passing through varargs.
        INFO_LOG("Thread: %p, p%d, %s\n", static_cast<void*>(thread), priority, threadName);
    }
}



// Returns this instance's thread list. When INFO-level logging is compiled in, it also
// logs every thread reachable from the calling thread through the OS all-threads list.
ThreadList* TargetApplication::GetThreadList()
{
#if (LOG_AS_INFO >= MINIMUM_LOG_TIER)
    INFO_LOG("Thread List\n");
    INFO_LOG("--\n");

    nn::os::ThreadType* startThread = nn::os::GetCurrentThread();
    PrintThreadInfo(startThread);

    ThreadNode* node = reinterpret_cast<ThreadNode*>(&startThread->_allThreadsListNode);

    {
        // Walk until the list wraps back to the starting thread. Also stop on a null link,
        // a self-link, or after ThreadIdListSize steps (the same guards
        // RegisterAllThreads() uses), so a malformed list cannot crash or hang logging.
        nn::os::ThreadType *next = node->next;
        size_t loopCount = 0;
        while (next != nullptr && next != startThread && loopCount < ThreadIdListSize)
        {
            auto n = reinterpret_cast<ThreadNode*>(&next->_allThreadsListNode)->next;
            PrintThreadInfo(next);
            if (next == n) { break; }
            next = n;
            ++loopCount;
        }
    }

    INFO_LOG("--\n");
#endif

    return &m_threadList;
}



uint32_t TargetApplication::GetSdkVersion() const
{
    return this->m_sdkVersion;
}



void TargetApplication::SetSdkVersion(uint32_t version)
{
    this->m_sdkVersion = version;
}



} // profiler
} // nn
