﻿/*--------------------------------------------------------------------------------*
  Copyright (C)Nintendo All rights reserved.

  These coded instructions, statements, and computer programs contain proprietary
  information of Nintendo and/or its licensed developers and are protected by
  national and international copyright laws. They may not be disclosed to third
  parties or copied or duplicated in any form, in whole or in part, without the
  prior written consent of Nintendo.

  The content herein is highly confidential and should be handled accordingly.
 *--------------------------------------------------------------------------------*/

#include <nn/nn_Common.h>
#include <nn/nn_Windows.h>
#include <nn/nn_SdkAssert.h>
#include <nn/diag/diag_Module.h>
#include <nn/util/util_ScopeExit.h>
#include <nn/util/util_StringUtil.h>

#include <psapi.h>
#pragma comment(lib, "psapi.lib")

namespace nn { namespace diag { namespace detail {

namespace
{
    // Converts a UTF-8 string into a newly allocated, null-terminated wide-character string.
    //
    // string : UTF-8 encoded input. It does not need to be null-terminated.
    // length : number of bytes of `string` to convert.
    //
    // Returns a buffer allocated with std::malloc; the caller releases it with std::free.
    // Returns nullptr when the buffer cannot be allocated.
    wchar_t* ToWideCharString(const char* string, size_t length) NN_NOEXCEPT
    {
        const auto stringSize = static_cast<int>(sizeof(char) * length);

        // First pass: query how many wide characters the conversion produces.
        const auto wideCharStringLength = MultiByteToWideChar(CP_UTF8, 0, string, stringSize, NULL, 0);

        // One extra element for the terminator that is appended below.
        const auto wideCharStringSize = sizeof(wchar_t) * (wideCharStringLength + 1);

        const auto pWideCharString = reinterpret_cast<LPWSTR>(std::malloc(wideCharStringSize));
        NN_SDK_ASSERT_NOT_NULL(pWideCharString);
        if (pWideCharString == nullptr)
        {
            // Asserts may be compiled out; never write through a null pointer.
            return nullptr;
        }

        if (wideCharStringLength > 0)
        {
            // The destination capacity (cchWideChar) is a count of wide characters, not bytes.
            const auto result = MultiByteToWideChar(CP_UTF8, 0, string, stringSize, pWideCharString, wideCharStringLength);
            NN_SDK_ASSERT_EQUAL(result, wideCharStringLength);
            NN_UNUSED(result);
        }

        pWideCharString[wideCharStringLength] = L'\0';

        return pWideCharString;
    }
}

// Finds the module whose address range [baseAddress, baseAddress + size) contains `address`.
//
// outPathBuffer  : receives the module path, copied with nn::util::Strlcpy.
// pathBufferSize : size passed to Strlcpy as the limit for outPathBuffer.
// address        : address to look up.
//
// Returns the base address of the containing module. Returns 0 when the module list
// cannot be obtained or no module contains `address`; outPathBuffer is not written then.
uintptr_t GetModulePathImpl(char* outPathBuffer, size_t pathBufferSize, uintptr_t address) NN_NOEXCEPT
{
    const auto bufferSize = nn::diag::GetRequiredBufferSizeForGetAllModuleInfo();
    if (!(bufferSize > 0))
    {
        NN_SDK_ASSERT(false);
        return 0u;
    }

    auto buffer = reinterpret_cast<nn::Bit8*>(malloc(bufferSize));
    if (buffer == nullptr)
    {
        // Asserts may be compiled out; do not hand a null buffer to GetAllModuleInfo.
        NN_SDK_ASSERT(false);
        return 0u;
    }
    NN_UTIL_SCOPE_EXIT { free(buffer); };

    nn::diag::ModuleInfo* modules;
    const auto moduleCount = nn::diag::GetAllModuleInfo(&modules, buffer, bufferSize);
    if (!(moduleCount > 0))
    {
        NN_SDK_ASSERT(false);
        return 0u;
    }

    for (auto i = 0; i < moduleCount; i++)
    {
        const auto& info = modules[i];
        const bool containsAddress = (info.baseAddress <= address) && (address < info.baseAddress + info.size);
        if (containsAddress)
        {
            // info.path points into `buffer`, which is freed on return, so copy it out.
            nn::util::Strlcpy(outPathBuffer, info.path, static_cast<int>(pathBufferSize));
            return info.baseAddress;
        }
    }

    return 0u;
}

// Enumerates the modules loaded in the current process.
//
// outRequiredBufferSize : receives the total number of bytes needed to store every module
//                         (sizeof(ModuleInfo) + path length + 1 per module). It is only
//                         written when the whole enumeration completes.
// outBuffer             : destination buffer; used only when skipSave is false.
// bufferSize            : size of outBuffer in bytes.
// skipSave              : when true, only the required size and module count are computed
//                         and outBuffer is not written.
//
// When saving, ModuleInfo entries are stored from the front of outBuffer and the path
// strings are packed from the back. If the next module does not fit, the function stops
// and returns the number of modules stored so far.
//
// Returns the number of modules processed, or 0 on failure.
int GetAllModuleInfoImpl(size_t* outRequiredBufferSize, void* outBuffer, size_t bufferSize, bool skipSave) NN_NOEXCEPT
{
    const HANDLE hCurrentProcess = GetCurrentProcess();
    if (hCurrentProcess == NULL)
    {
        NN_SDK_ASSERT(false);
        return 0;
    }

    // First call: query the number of bytes needed for the module handle array.
    DWORD requiredModulesSize = 0u;

    EnumProcessModules(hCurrentProcess, NULL, 0u, &requiredModulesSize);
    if (!(requiredModulesSize > 0u))
    {
        NN_SDK_ASSERT(false);
        return 0;
    }

    const DWORD allocatedModulesSize = requiredModulesSize;
    auto hModules = reinterpret_cast<HMODULE*>(malloc(allocatedModulesSize));
    if (hModules == nullptr)
    {
        NN_SDK_ASSERT(false);
        return 0;
    }
    NN_UTIL_SCOPE_EXIT { free(hModules); };

    if (!EnumProcessModules(hCurrentProcess, hModules, allocatedModulesSize, &requiredModulesSize))
    {
        NN_SDK_ASSERT(false);
        return 0;
    }

    // Modules may have been loaded or unloaded between the two calls. Only the handles that
    // were actually written, and that fit in the allocation, are valid.
    const DWORD filledModulesSize = (requiredModulesSize < allocatedModulesSize) ? requiredModulesSize : allocatedModulesSize;
    const int actualModuleCount = static_cast<int>(filledModulesSize / sizeof(HMODULE));

    int moduleCount = 0;
    size_t requiredBufferSize = 0u;
    auto pInfo = reinterpret_cast<ModuleInfo*>(outBuffer);
    auto pPath = reinterpret_cast<char*>(outBuffer) + bufferSize; // Path strings are packed from the end of the buffer.

    for (auto i = 0; i < actualModuleCount; i++)
    {
        TCHAR tcharPath[MAX_PATH];
        const auto tcharPathLength = GetModuleFileNameEx(hCurrentProcess, hModules[i], tcharPath, MAX_PATH);
        if (tcharPathLength == 0)
        {
            NN_SDK_ASSERT(false);
            return 0;
        }

    #ifdef UNICODE
        // Convert the UTF-16 path to UTF-8. The first call only queries the required size.
        const auto multibytePathLength = WideCharToMultiByte(CP_UTF8, 0, tcharPath, static_cast<int>(tcharPathLength), NULL, 0, NULL, NULL);
        if (multibytePathLength <= 0)
        {
            // The source path is non-empty, so a zero length means the conversion failed.
            NN_SDK_ASSERT(false);
            return 0;
        }

        auto multibytePath = reinterpret_cast<char*>(malloc(multibytePathLength));
        if (multibytePath == nullptr)
        {
            NN_SDK_ASSERT(false);
            return 0;
        }
        NN_UTIL_SCOPE_EXIT { free(multibytePath); };

        if (WideCharToMultiByte(CP_UTF8, 0, tcharPath, static_cast<int>(tcharPathLength), multibytePath, multibytePathLength, NULL, NULL) != multibytePathLength)
        {
            NN_SDK_ASSERT(false);
            return 0;
        }

        const char* path = multibytePath;
        const size_t pathLength = multibytePathLength;
    #else
        const char* path = tcharPath;
        const size_t pathLength = tcharPathLength;
    #endif

        // Each module needs one ModuleInfo entry plus its null-terminated path.
        requiredBufferSize += sizeof(ModuleInfo) + pathLength + 1;

        if (!skipSave)
        {
            if (requiredBufferSize > bufferSize)
            {
                // The buffer is full: report only the modules stored so far.
                NN_SDK_ASSERT_LESS(moduleCount, actualModuleCount);
                return moduleCount;
            }

            MODULEINFO info;
            if (!GetModuleInformation(hCurrentProcess, hModules[i], &info, sizeof(info)))
            {
                NN_SDK_ASSERT(false);
                return 0;
            }

            // Reserve space for the path at the back of the buffer and terminate it.
            pPath -= pathLength + 1;
            std::memcpy(pPath, path, pathLength);
            pPath[pathLength] = '\0';

            pInfo->path = pPath;
            pInfo->baseAddress = reinterpret_cast<uintptr_t>(info.lpBaseOfDll);
            pInfo->size = static_cast<size_t>(info.SizeOfImage);
            pInfo++;
        }

        moduleCount++;
    }

    *outRequiredBufferSize = requiredBufferSize;
    NN_SDK_ASSERT_EQUAL(moduleCount, actualModuleCount);
    return moduleCount;
}

// Computes the address range of the ".rdata" section of the module loaded at baseAddress.
//
// The module's file is opened and mapped read-only, and its PE section headers are read
// from that mapping; the section's VirtualAddress and VirtualSize are then applied to
// baseAddress.
//
// outStartAddress : receives baseAddress + VirtualAddress of the section.
// outEndAddress   : receives the start address plus the section's VirtualSize.
// baseAddress     : must be exactly the base address of a loaded module.
//
// Returns true when a section whose name begins with ".rdata" is found. Returns false when
// baseAddress is not a module base, the file cannot be opened or mapped, the DOS/NT
// signatures do not match, or no such section exists. The outputs are not written on failure.
bool GetReadOnlyDataSectionRangeImpl(uintptr_t* outStartAddress, uintptr_t* outEndAddress, uintptr_t baseAddress) NN_NOEXCEPT
{
    NN_SDK_REQUIRES_NOT_NULL(outStartAddress);
    NN_SDK_REQUIRES_NOT_NULL(outEndAddress);

    char path[_MAX_PATH + 1];
    auto actualBaseAddress = GetModulePathImpl(path, sizeof(path), baseAddress);
    if (!(actualBaseAddress > 0u && baseAddress == actualBaseAddress))
    {
        return false;
    }

    auto wideCharPath = ToWideCharString(path, util::Strnlen(path, sizeof(path)));
    if (wideCharPath == nullptr)
    {
        return false;
    }
    NN_UTIL_SCOPE_EXIT
    {
        std::free(wideCharPath);
    };

    // The path is always a wide string here, so call the W variant explicitly instead of
    // the TCHAR macro, which resolves to the ANSI variant when UNICODE is not defined.
    auto hFile = CreateFileW(wideCharPath, GENERIC_READ, FILE_SHARE_READ, NULL, OPEN_EXISTING, FILE_ATTRIBUTE_NORMAL, 0);
    if (hFile == INVALID_HANDLE_VALUE)
    {
        return false;
    }
    NN_UTIL_SCOPE_EXIT
    {
        CloseHandle(hFile);
    };

    auto hFileMapping = CreateFileMapping(hFile, NULL, PAGE_READONLY, 0, 0, NULL);
    if (hFileMapping == 0)
    {
        return false;
    }
    NN_UTIL_SCOPE_EXIT
    {
        CloseHandle(hFileMapping);
    };

    auto lpFileBase = MapViewOfFile(hFileMapping, FILE_MAP_READ, 0, 0, 0);
    if (lpFileBase == 0)
    {
        return false;
    }
    NN_UTIL_SCOPE_EXIT
    {
        UnmapViewOfFile(lpFileBase);
    };

    auto pDosHeader = reinterpret_cast<PIMAGE_DOS_HEADER>(lpFileBase);
    if (pDosHeader->e_magic != IMAGE_DOS_SIGNATURE)
    {
        return false;
    }

    // e_lfanew is the file offset of the NT headers.
    auto pNtHeader = reinterpret_cast<PIMAGE_NT_HEADERS>(reinterpret_cast<BYTE*>(pDosHeader) + pDosHeader->e_lfanew);
    if (pNtHeader->Signature != IMAGE_NT_SIGNATURE)
    {
        return false;
    }

    auto pSectionHeaders = IMAGE_FIRST_SECTION(pNtHeader);

    for (int i = 0; i < pNtHeader->FileHeader.NumberOfSections; i++)
    {
        auto pSectionHeader = pSectionHeaders + i;

        // Compare only the length of ".rdata" without its terminator.
        if (util::Strncmp(reinterpret_cast<const char*>(pSectionHeader->Name), ".rdata", sizeof(".rdata") - 1) == 0)
        {
            *outStartAddress = baseAddress + pSectionHeader->VirtualAddress;
            *outEndAddress = baseAddress + pSectionHeader->VirtualAddress + pSectionHeader->Misc.VirtualSize;
            return true;
        }
    }

    return false;
}

}}} // nn::diag::detail
