﻿/*--------------------------------------------------------------------------------*
  Copyright (C)Nintendo All rights reserved.

  These coded instructions, statements, and computer programs contain proprietary
  information of Nintendo and/or its licensed developers and are protected by
  national and international copyright laws. They may not be disclosed to third
  parties or copied or duplicated in any form, in whole or in part, without the
  prior written consent of Nintendo.

  The content herein is highly confidential and should be handled accordingly.
 *--------------------------------------------------------------------------------*/
/*
* Copyright (c) 2015 NVIDIA Corporation.  All rights reserved.
*
* NVIDIA Corporation and its licensors retain all intellectual property
* and proprietary rights in and to this software, related documentation
* and any modifications thereto.  Any use, reproduction, disclosure or
* distribution of this software and related documentation without an express
* license agreement from NVIDIA Corporation is strictly prohibited.
*/

#include <atomic>
#include <cstring>

#include <nn/nn_Common.h>
#include <nn/nn_SdkAssert.h>
#include <nn/os.h>
#include <nn/htcs.h>
#include <nn/diag/diag_Module.h>
#include <nn/ro/detail/ro_RoModule.h>

#include "profiler_Logging.h"
#include "profiler_Memory.h"
#include "profiler_RecordMethods.h"
#include "profiler_TargetApplication.h"
#include "profiler_Svc.autogen.h"
#include "profiler_MemoryServer.h"

#include "profiler_Defines.h"

namespace nn { namespace profiler {


namespace /*anonymous*/
{
    // Size in bytes of the stack allocated for the transmit thread.
    const size_t ThreadStackSize = 24 * 1024;
    // HTCS port name the memory profiler server binds to. The pointer itself is
    // const so the name cannot be re-pointed at runtime.
    const char* const NVMEMORY_PORT_NAME = "@NvMemoryProfiler";

    // File-local state of the memory profiler server.
    struct
    {
        // Transmit thread and the stack memory allocated for it.
        nn::os::ThreadType transmitThread;
        void* threadStack;

        // Shared ring area handed to InitializeMemoryProfilerServer().
        SharedArea* sharedArea;
        // Staging buffer the ring contents are copied into before sending.
        char* sendBuffer;
        // Signaled by SendMemoryEventToPc() to wake the transmit thread.
        nn::os::EventType dataToSendEvent;
        nn::os::EventType dataSentEvent;
        // HTCS descriptors used by the transmit thread.
        int listenSocket;
        int connectionSocket;

        std::atomic<bool> isInitialized;
        // Written by FinalizeMemoryProfilerServer() and polled by the transmit
        // thread, so it must be atomic rather than a plain bool.
        std::atomic<bool> shouldExit;
        // Initial signaled state used for dataToSendEvent at initialization.
        bool startTriggered;
        // Written by the transmit thread and read through
        // IsConnectedToMemoryProfilerServer(), so it must be atomic as well.
        std::atomic<bool> connected;
    } globals;


    int Send(int connFd, void *packet, size_t packetsize)
    {
        char *buf;
        size_t bytes;
        nn::htcs::ssize_t ret;

        for (buf = (char *)packet, bytes = 0; bytes < packetsize; bytes += size_t(ret))
        {
            ret = nn::htcs::Send(connFd, buf + bytes, packetsize - bytes, 0);
            if (ret == 0)
            {
                ERROR_LOG("received EOF when writing request\n");
                return -1;
            }
            else if (ret < 0)
            {
                if (nn::htcs::GetLastError() == nn::htcs::HTCS_EINTR)
                {
                    ret = 0;
                    continue;
                }
                else
                {
                    ERROR_LOG("failed to write packet (%d)\n", nn::htcs::GetLastError());
                    return -1;
                }
            }
        }
        return 0;
    }


    // Appends one module record to buf: base address (uint64), size (uint64)
    // and the NUL-terminated module name.
    //
    // x describes the module's executable region. The name is looked up in the
    // memory region that starts immediately after it (x.baseAddress + x.size).
    // If that region cannot be queried or no name is found, an empty name is
    // written.
    //
    // Returns false without writing anything when buf may not have room for a
    // worst-case record, true otherwise.
    template <int N>
    static bool AddModuleToBuffer(Buffer<N>& buf, nn::svc::MemoryInfo& x)
    {
        // Worst case: 16 bytes of address/size, a name of up to
        // MaximumFilePathLength bytes, and the terminating NUL.
        if (buf.size + MaximumFilePathLength + 16 + 1 > static_cast<size_t>(buf.GetCapacity()))
        {
            return false;
        }

        nn::svc::MemoryInfo memInfo;
        nn::svc::PageInfo pageInfo;

        buf.template Write<uint64_t>(x.baseAddress);
        buf.template Write<uint64_t>(x.size);

        // First method that makes use of built-in SDK features.
        // However, it has issues with certain NSS/NRS builds.
        //auto moduleaddr = nn::diag::GetModulePath(&buf.buf[buf.size], MaximumFilePathLength, memInfo.baseAddress);
        //if (moduleaddr != 0)
        //{
        //    size_t moduleNameLength = strnlen(&buf.buf[buf.size], MaximumFilePathLength);
        //    buf.size += moduleNameLength;
        //}
        //buf.template Write<char>(0);

        uintptr_t readOnlyCodeAddress = static_cast<uintptr_t>(x.baseAddress + x.size);
        nn::Result result = nn::svc::profiler::QueryMemory(&memInfo, &pageInfo, readOnlyCodeAddress);
        NN_SDK_ASSERT(result.IsSuccess());

        // Initialized so that a lookup which leaves the out-arguments untouched
        // yields an empty name rather than reading indeterminate values.
        const char* pName = nullptr;
        size_t nameLength = 0;
        if (result.IsSuccess())
        {
            // Only search when memInfo was actually filled in; with asserts
            // compiled out, a failed query would otherwise be used as-is.
            TargetApplication::FindModuleName(
                static_cast<uintptr_t>(memInfo.baseAddress),
                static_cast<uintptr_t>(memInfo.baseAddress + memInfo.size),
                pName,
                nameLength);
        }

        if (pName != nullptr && nameLength > 0)
        {
            // The capacity check above reserved MaximumFilePathLength bytes for
            // the name, so never copy more than that.
            if (nameLength > static_cast<size_t>(MaximumFilePathLength))
            {
                nameLength = static_cast<size_t>(MaximumFilePathLength);
            }
            memcpy(&buf.buf[buf.size], pName, nameLength);
            buf.size += nameLength;
        }
        buf.template Write<char>(0);

        return true;
    }


    // Writes the module list to buf: a uint32 module count followed by one
    // record per module (see AddModuleToBuffer).
    //
    // The address space is walked from address 0 with QueryMemory. Every region
    // that is ReadExecute and in the Code or AliasCode state is treated as a
    // module. The walk stops when a query fails, the buffer is full, an
    // Inaccessible region is reached, or the walk stops making forward progress.
    // The count is written as a placeholder first and patched at the end.
    template <int N>
    static void WriteModulesToBuffer(Buffer<N>& buf)
    {
        uint32_t moduleCount = 0;
        char* pModuleCount = &buf.buf[buf.size];
        buf.template Write<uint32_t>(0);

        uintptr_t addr = 0;
        while (NN_STATIC_CONDITION(true))
        {
            nn::svc::MemoryInfo memInfo;
            nn::svc::PageInfo pageInfo;
            nn::Result result = nn::svc::profiler::QueryMemory(&memInfo, &pageInfo, addr);
            if (result.IsFailure())
            {
                // %p requires a pointer argument, so convert the integer address.
                ERROR_LOG("Failed to query address: %p\n", reinterpret_cast<void*>(addr));
                DumpResultInformation(LOG_AS_ERROR, result);
                break;
            }

            const uintptr_t nextAddr = static_cast<uintptr_t>(memInfo.baseAddress + memInfo.size);

            if (memInfo.permission == nn::svc::MemoryPermission_ReadExecute &&
                (memInfo.state == nn::svc::MemoryState_AliasCode || memInfo.state == nn::svc::MemoryState_Code))
            {
                // Found a module
                if (!AddModuleToBuffer(buf, memInfo)) { break; }
                ++moduleCount;
            }
            else if (memInfo.state == nn::svc::MemoryState_Inaccessible)
            {
                break;
            }

            // Stop if the walk would not advance (for example when the end
            // address wraps around). This is checked for module regions too, so
            // a wrap cannot restart the scan and record the same modules again.
            if (nextAddr <= addr)
            {
                break;
            }

            addr = nextAddr;
        }

        memcpy(pModuleCount, &moduleCount, sizeof(uint32_t));
    }


    void TransmitThreadFunc(void* arg)
    {
        NN_UNUSED(arg);

        while (!globals.shouldExit)
        {
            globals.listenSocket = nn::htcs::Socket();
            if (globals.listenSocket < 0)
            {
                auto lastError = nn::htcs::GetLastError();
                if (lastError != nn::htcs::HTCS_ENETDOWN)
                {
                    ERROR_LOG("failed to create socket, error %d\n", nn::htcs::GetLastError());
                    return;
                }
                else
                {
                    nn::os::SleepThread(nn::TimeSpan::FromSeconds(1));
                    continue;
                }
            }

            nn::htcs::SockAddrHtcs addr;
            addr.family = nn::htcs::HTCS_AF_HTCS;
            addr.peerName = nn::htcs::GetPeerNameAny();
            strcpy(addr.portName.name, NVMEMORY_PORT_NAME);

            {
                int ret;

                ret = nn::htcs::Bind(globals.listenSocket, &addr);
                if (ret < 0)
                {
                    ERROR_LOG("failed to bind listening fd to port, error %d\n", nn::htcs::GetLastError());
                    continue;
                }

                ret = nn::htcs::Listen(globals.listenSocket, 10);
                if (ret < 0)
                {
                    ERROR_LOG("failed to setup listening fd, error %d\n", nn::htcs::GetLastError());
                    continue;
                }
            }

            while (!globals.shouldExit)
            {
                INFO_LOG("Waiting to connect memory profiler in %s...\n", NVMEMORY_PORT_NAME);

                globals.connectionSocket = nn::htcs::Accept(globals.listenSocket, &addr);
                if (globals.connectionSocket < 0)
                {
                    ERROR_LOG("failed to accept connection, error %d\n", nn::htcs::GetLastError());
                    break;
                }

                globals.connected = true;

                // Handshake
                const char name[] = "NX";
                const size_t len = strlen(name);

                const int InitBufferCapacity = 16 * 512 + 256;
                Buffer<InitBufferCapacity> buf;
                buf.Write<uint8_t>(0);
                buf.Write<uint8_t>(1);
                buf.Write<uint16_t>(0); // opcode
                char* packetLength = &buf.buf[buf.size];
                buf.Write<uint16_t>(0); // length of init packet
                buf.Write<uint32_t>(3); // version
                buf.Write<uint32_t>(static_cast<uint16_t>(len)); // length of name
                buf.Write(name, len); // client name

                WriteModulesToBuffer(buf);

                ptrdiff_t initPacketLength = &buf.buf[buf.size] - packetLength - 2;
                NN_SDK_ASSERT(initPacketLength > 0);
                NN_SDK_ASSERT(initPacketLength < (1 << (sizeof(uint16_t) * 8)));
                uint16_t writePacketLength = static_cast<uint16_t>(initPacketLength);
                memcpy(packetLength, &writePacketLength, sizeof(writePacketLength));

                auto sendResult = Send(globals.connectionSocket, buf.buf, buf.size);
                if (sendResult < 0)
                {
                    break;
                }

                // Receive ack
                struct
                {
                    uint16_t zero;
                    uint16_t len;
                    uint16_t numCallstacks;
                } ack;
                auto recvAmt = nn::htcs::Recv(globals.connectionSocket, &ack, sizeof(ack), nn::htcs::HTCS_MSG_WAITALL);
                if (recvAmt <= 0)
                {
                    ERROR_LOG("Failed to receive notification from PC\n");
                    break;
                }
                NN_SDK_ASSERT(recvAmt == sizeof(ack));
                NN_SDK_ASSERT(ack.zero == 0);
                NN_SDK_ASSERT(ack.len == 2);
                INFO_LOG("Connected memory profiler.\n");

                globals.sharedArea->numCallstacks = ack.numCallstacks;

                while (!globals.shouldExit)
                {
                    nn::os::WaitEvent(&globals.dataToSendEvent);
                    nn::os::ClearEvent(&globals.dataToSendEvent);
                    if (globals.shouldExit) { break; }

                    size_t packetsize;
                    uint32_t get = globals.sharedArea->get;
                    uint32_t put = globals.sharedArea->put;
                    char* ring = globals.sharedArea->ring;

                    if (get == put)
                    {
                        continue;
                    }

                    std::atomic_thread_fence(std::memory_order_acquire);

                    // Send everything at once.  Copy to a local buffer because for some
                    // reason transmitting directly from shared mem doesn't work.
                    char *temp = globals.sendBuffer;
                    NN_SDK_ASSERT((int)(SharedMemorySizeMax) >= RingSize);
                    packetsize = put > get ? put - get : RingSize - get;
                    memcpy(temp, ring + get, packetsize);
                    if (put < get)
                    {
                        memcpy(temp + packetsize, ring, put);
                        packetsize += put;
                    }

                    if (Send(globals.connectionSocket, temp, packetsize) < 0)
                    {
                        INFO_LOG("Disconnected memory profiler.\n");
                        nn::htcs::Close(globals.connectionSocket);
                        break;
                    }

                    std::atomic_thread_fence(std::memory_order_release);
                    globals.sharedArea->get = put;
                }

                nn::htcs::Close(globals.connectionSocket);
                globals.connected = false;
            }

            nn::htcs::Close(globals.listenSocket);
            globals.connected = false;
        }
    } // NOLINT(impl/function_size)

} // anonymous


// Starts the memory profiler server if it is not already initialized.
//
// data is the shared ring area the transmit thread reads from. This allocates
// the send buffer and the thread stack, initializes the events, and then
// creates and starts the transmit thread. Does nothing when the server is
// already initialized.
void InitializeMemoryProfilerServer(SharedArea* data)
{
    if (globals.isInitialized == false)
    {
        globals.sharedArea = data;

        globals.sendBuffer = reinterpret_cast<char*>(Memory::GetInstance()->Allocate(SharedMemorySizeMax));

        globals.shouldExit = false;

        // Mark both descriptors as "not open" so that finalization never
        // closes descriptor 0 when no socket was ever created or accepted.
        globals.listenSocket = -1;
        globals.connectionSocket = -1;

        // The events must exist before the transmit thread starts, because
        // that thread waits on dataToSendEvent.
        nn::os::InitializeEvent(&globals.dataToSendEvent, globals.startTriggered, nn::os::EventClearMode_AutoClear);
        nn::os::InitializeEvent(&globals.dataSentEvent, false, nn::os::EventClearMode_AutoClear);

        globals.threadStack = Memory::GetInstance()->Allocate(ThreadStackSize, nn::os::ThreadStackAlignment);
        nn::os::CreateThread(
            &globals.transmitThread,
            TransmitThreadFunc,
            nullptr,
            globals.threadStack,
            ThreadStackSize,
            nn::os::HighestThreadPriority + 1);
        nn::os::SetThreadName(&globals.transmitThread, "NX CPU Profiler: Memory Server");
        nn::os::StartThread(&globals.transmitThread);

        globals.isInitialized = true;
    }
}


// Reports whether the transmit thread currently has an accepted connection,
// as recorded in globals.connected.
bool IsConnectedToMemoryProfilerServer()
{
    const bool isConnected = globals.connected;
    return isConnected;
}


// Notifies the transmit thread that the shared ring has data to send.
//
// When the server is initialized, dataToSendEvent is signaled. When it is not,
// the request is recorded in startTriggered, which is used as the initial
// signaled state of dataToSendEvent at initialization.
void SendMemoryEventToPc()
{
    if (globals.isInitialized)
    {
        nn::os::SignalEvent(&globals.dataToSendEvent);
    }
    else
    {
        // Record the early trigger. Storing false here could never have any
        // effect, because nothing else sets the flag to true.
        globals.startTriggered = true;
    }
}


// Shuts the memory profiler server down.
//
// shouldExit is always set. If the server was initialized, this closes both
// sockets, signals dataToSendEvent, waits for the transmit thread and destroys
// it, frees the thread stack and the send buffer, finalizes the events, and
// zeroes all of globals.
void FinalizeMemoryProfilerServer()
{
    globals.shouldExit = true;

    const bool wasInitialized = globals.isInitialized.exchange(false);
    if (!wasInitialized)
    {
        return;
    }

    // Close the sockets and signal the event before waiting for the thread.
    nn::htcs::Close(globals.connectionSocket);
    nn::htcs::Close(globals.listenSocket);
    nn::os::SignalEvent(&globals.dataToSendEvent);

    nn::os::WaitThread(&globals.transmitThread);
    nn::os::DestroyThread(&globals.transmitThread);

    // Release resources only after the thread has been destroyed.
    Memory::GetInstance()->Free(globals.threadStack);
    Memory::GetInstance()->Free(globals.sendBuffer);

    nn::os::FinalizeEvent(&globals.dataToSendEvent);
    nn::os::FinalizeEvent(&globals.dataSentEvent);

    memset(&globals, 0, sizeof(globals));
}


}}
