﻿/*--------------------------------------------------------------------------------*
  Copyright (C)Nintendo All rights reserved.

  These coded instructions, statements, and computer programs contain proprietary
  information of Nintendo and/or its licensed developers and are protected by
  national and international copyright laws. They may not be disclosed to third
  parties or copied or duplicated in any form, in whole or in part, without the
  prior written consent of Nintendo.

  The content herein is highly confidential and should be handled accordingly.
 *--------------------------------------------------------------------------------*/

#include <algorithm>
#include <climits>
#include <cstring>
#include <memory>
#include <string>
#include <vector>

#include <nn/fs.h>
#include <nn/nn_Log.h>
#include <nn/nn_SdkLog.h>
#include <nn/init.h>
#include <nn/os.h>
#include <nn/settings/factory/settings_SerialNumber.h>
#include <nn/settings/factory/settings_ConfigurationId.h>
#include <nn/settings/system/settings_FirmwareVersion.h>
#include <nn/util/util_ScopeExit.h>

#include "RecoveryWriterUsb.h"

// Logging helpers that tag each message with its outcome ("[ERROR] " / "[SUCCESS] ").
// ERROR_LOG_F / SUCCESS_LOG_F: x must be a printf-style format string *literal* (it is joined
// to the tag and to "" by string-literal concatenation), followed by the values it formats.
// ERROR_LOG / SUCCESS_LOG: log a single C string verbatim through "%s", so characters such as
// '%' inside the message are not treated as format directives.
#define ERROR_LOG_F(x, ...) NN_LOG("[ERROR] " x "", __VA_ARGS__)
#define ERROR_LOG(...) ERROR_LOG_F("%s", __VA_ARGS__)

#define SUCCESS_LOG_F(x, ...) NN_LOG("[SUCCESS] " x "", __VA_ARGS__)
#define SUCCESS_LOG(...) SUCCESS_LOG_F("%s", __VA_ARGS__)

namespace {
    // Device interface class GUID that EnumerateDevice passes to SetupDiGetClassDevs and
    // SetupDiEnumDeviceInterfaces to find candidate devices.
    const GUID TargetGuid =
    {
        0x97FFFD48,
        0x2D1D,
        0x47A0,
        { 0x85, 0xA3, 0x07, 0xFD, 0xE6, 0xFA, 0x01, 0x43 }
    };

    // Process exit codes returned by nnMain.
    enum ExitCode : int
    {
        ExitCode_Success = 0, //!< Exit code indicating success
        ExitCode_Failure = 1, //!< Exit code indicating failure
    };

    // Status codes returned by the helper and command functions in this file.
    enum ErrorCode : int
    {
        ErrorCode_Success = 0,  //!< Error code indicating success
        ErrorCode_Failure = 1,  //!< Error code indicating failure
    };

    // Transfer timeout in milliseconds, applied to every pipe via PIPE_TRANSFER_TIMEOUT.
    const uint32_t TimeOut = 5000;

    // Capacity of the serial-number table that nnMain allocates for enumeration results.
    const int MaxDeviceNum = 1024;

    // One entry of the command dispatch table: a command name and the handler that runs it.
    struct Command final
    {
        // Handler signature: the selected device plus the full host argument list.
        typedef ErrorCode (*Handler)(UsbDeviceInfo*, ::std::vector<::std::string>&);

        const char* name;
        Handler function;
    };

    // String descriptor text that IsSafeModeDevice looks for to recognize a target device.
    const wchar_t recoveryUsbString[] = L"recovery usb";

    // Orders serial numbers by strcmp on their strings; used to sort the device listing.
    bool Compare(const ::nn::settings::factory::SerialNumber& a, const ::nn::settings::factory::SerialNumber& b)
    {
        const int order = strcmp(a.string, b.string);
        return order < 0;
    }
}

// Checks that args holds exactly `size` entries (args[0] is the program name).
// On mismatch, logs the number of user-supplied arguments (excluding args[0]) and returns
// ErrorCode_Failure; otherwise returns ErrorCode_Success.
ErrorCode CheckArgumentsCount(::std::vector<std::string>& args, int size) NN_NOEXCEPT
{
    if (args.size() != static_cast<size_t>(size))
    {
        // %d requires an int: args.size() is size_t, and size() - 1 would wrap on an empty vector.
        const int argumentCount = args.empty() ? 0 : static_cast<int>(args.size()) - 1;
        ERROR_LOG_F("Invalid arguments count: %d\n", argumentCount);
        return ErrorCode_Failure;
    }

    return ErrorCode_Success;
}

// Fills in the framing header that precedes a payload sent to the device.
// begin/end are fixed marker words (ReceiveData checks the same values on incoming headers);
// bytes is the payload size that follows the header.
void MakeDsHeader(DsHeader* pHeader, int bytesSend) NN_NOEXCEPT
{
    NN_SDK_REQUIRES_NOT_NULL(pHeader);

    DsHeader& header = *pHeader;
    header.begin = 0x12345678;
    header.bytes = bytesSend;
    header.end   = 0xabcdef;
}

// Writes sendSize bytes from pInBuffer to the device's bulk-OUT pipe with one
// WinUsb_WritePipe call (no OVERLAPPED structure). *pBytesTransferred receives the count
// WinUSB reports. Succeeds only if the call succeeds AND the whole buffer was transferred.
ErrorCode SendData(ULONG* pBytesTransferred, UsbDeviceInfo* pDeviceInfo, void* pInBuffer, uint32_t sendSize) NN_NOEXCEPT
{
    NN_SDK_REQUIRES_NOT_NULL(pBytesTransferred);
    NN_SDK_REQUIRES_NOT_NULL(pDeviceInfo);
    NN_SDK_REQUIRES_NOT_NULL(pInBuffer);

    uint8_t* const payload = static_cast<uint8_t*>(pInBuffer);
    const bool pipeWritten = WinUsb_WritePipe(pDeviceInfo->winUsbHandle,
                                              pDeviceInfo->pipeBulkOut,
                                              payload,
                                              sendSize,
                                              pBytesTransferred,
                                              NULL) != false;

    // Short-circuit: the transferred count is only inspected when the write call succeeded.
    const bool complete = pipeWritten && (*pBytesTransferred == sendSize);
    return complete ? ErrorCode_Success : ErrorCode_Failure;
}

// Sends a command to the device as two writes: a DsHeader announcing the command length,
// then the command characters themselves (without a terminating NUL).
ErrorCode SendCommand(UsbDeviceInfo* pDeviceInfo, ::std::string commandName) NN_NOEXCEPT
{
    NN_SDK_REQUIRES_NOT_NULL(pDeviceInfo);

    ULONG bytesTransferred = 0;
    const uint32_t bytesSend = static_cast<uint32_t>(commandName.size());

    DsHeader header;
    MakeDsHeader(&header, bytesSend);

    if (SendData(&bytesTransferred, pDeviceInfo, &header, sizeof(header)) != ErrorCode_Success)
    {
        return ErrorCode_Failure;
    }

    // Send the whole command, not just its first character. commandName is this function's own
    // copy, so &commandName[0] is a writable pointer to size() contiguous chars, and it is
    // non-null even for an empty string.
    if (SendData(&bytesTransferred, pDeviceInfo, &commandName[0], bytesSend) != ErrorCode_Success)
    {
        return ErrorCode_Failure;
    }

    return ErrorCode_Success;
}

// Receives one framed message from the device's bulk-IN pipe: a DsHeader (begin/end markers
// must be 0x12345678 / 0xabcdef), then header.bytes payload bytes written to pInBuffer.
// Fails if either read fails, the header is short or malformed, the announced payload does not
// fit in bufferSize bytes, or fewer payload bytes than announced arrive.
ErrorCode ReceiveData(::nn::Bit8* pInBuffer, uint32_t bufferSize, UsbDeviceInfo* pDeviceInfo) NN_NOEXCEPT
{
    NN_SDK_REQUIRES_NOT_NULL(pInBuffer);
    NN_SDK_REQUIRES_NOT_NULL(pDeviceInfo);

    ULONG bytesTransferred = 0;
    DsHeader header;

    if (WinUsb_ReadPipe(
        pDeviceInfo->winUsbHandle,
        pDeviceInfo->pipeBulkIn,
        reinterpret_cast<uint8_t *>(&header),
        sizeof(DsHeader),
        &bytesTransferred,
        NULL
        ) == false)
    {
        return ErrorCode_Failure;
    }

    if (bytesTransferred != sizeof(DsHeader))
    {
        return ErrorCode_Failure;
    }

    if ((header.begin != 0x12345678) || (header.end != 0xabcdef))
    {
        return ErrorCode_Failure;
    }

    // Never read more than the caller's buffer can hold.
    if (header.bytes > bufferSize)
    {
        return ErrorCode_Failure;
    }

    if (WinUsb_ReadPipe(
        pDeviceInfo->winUsbHandle,
        pDeviceInfo->pipeBulkIn,
        pInBuffer,
        header.bytes,
        &bytesTransferred,
        NULL
        ) == false)
    {
        return ErrorCode_Failure;
    }

    // A short payload read would leave part of the caller's buffer unfilled; reject it the same
    // way a short header read is rejected above.
    if (bytesTransferred != static_cast<ULONG>(header.bytes))
    {
        return ErrorCode_Failure;
    }

    return ErrorCode_Success;
}

// Returns true if one of the device's string descriptors equals recoveryUsbString
// (L"recovery usb"). Descriptor indices are read in order starting at 0. Scanning stops at the
// first index WinUsb_GetDescriptor fails to return, or after index 0xFF (indices are 8-bit).
bool IsSafeModeDevice(WINUSB_INTERFACE_HANDLE winUsbHandle) NN_NOEXCEPT
{
    // One extra WCHAR so the terminator written below always fits, even when the descriptor
    // fills all MAXIMUM_USB_STRING_LENGTH bytes requested from WinUSB.
    const size_t BufferSize = MAXIMUM_USB_STRING_LENGTH + sizeof(WCHAR);
    ::std::unique_ptr<char[]> buffer(new char[BufferSize]);
    auto stringDescriptor = reinterpret_cast<PUSB_STRING_DESCRIPTOR>(buffer.get());

    // An int counter avoids the endless loop that wrapping an 8-bit index would cause.
    const int MaxDescriptorIndex = 0xFF;
    for (int descriptorIndex = 0; descriptorIndex <= MaxDescriptorIndex; ++descriptorIndex)
    {
        ULONG bytesWritten = 0;
        // NOTE(review): LANG_ENGLISH is a primary language id; a full LANGID such as 0x0409 may be
        // what some devices expect. Kept as-is because the target devices work with it; confirm.
        if (!WinUsb_GetDescriptor(winUsbHandle,
            USB_STRING_DESCRIPTOR_TYPE,
            static_cast<UCHAR>(descriptorIndex),
            LANG_ENGLISH,
            reinterpret_cast<PUCHAR>(stringDescriptor),
            MAXIMUM_USB_STRING_LENGTH,
            &bytesWritten))
        {
            break;
        }

        // Trust only bytes that were both written and claimed by bLength. The first two bytes
        // are bLength/bDescriptorType; the rest are UTF-16 code units.
        ULONG validLength = stringDescriptor->bLength;
        if (validLength > bytesWritten)
        {
            validLength = bytesWritten;
        }
        const ULONG charCount = (validLength >= 2) ? (validLength - 2) / 2 : 0;
        stringDescriptor->bString[charCount] = L'\0';

        if (wcscmp(recoveryUsbString, stringDescriptor->bString) == 0)
        {
            return true;
        }
    }

    return false;
}

// Sends command "1" to the device and receives its serial number into serialNumber.
// Returns false if either transfer fails. On success, serialNumber.string is guaranteed to be
// NUL-terminated even if the device sent a short or unterminated reply.
bool GetTargetSerialNumber(UsbDeviceInfo* pDeviceInfo,
                          ::nn::settings::factory::SerialNumber& serialNumber)
{
    NN_SDK_REQUIRES_NOT_NULL(pDeviceInfo);
    if (SendCommand(pDeviceInfo, "1") != ErrorCode_Success)
    {
        return false;
    }

    // Zero-fill first: ReceiveData may write fewer than sizeof(SerialNumber) bytes.
    serialNumber = ::nn::settings::factory::SerialNumber();
    if (ReceiveData(reinterpret_cast<::nn::Bit8 *>(&serialNumber),
        sizeof(::nn::settings::factory::SerialNumber),
        pDeviceInfo) != ErrorCode_Success)
    {
        return false;
    }

    // Callers compare (strcmp / std::string ==) and print this as a C string, so force a terminator.
    serialNumber.string[sizeof(serialNumber.string) - 1] = '\0';
    return true;
}

// Probes one opened device handle.
// - Initializes WinUSB on *deviceHandle.
// - If the device is a safe-mode device (see IsSafeModeDevice), locates its bulk IN/OUT pipes,
//   applies the TimeOut transfer timeout to every pipe and queries its serial number.
// - Appends the serial number to safeModeDevice/*deviceNum. This assumes safeModeDevice has room
//   for MaxDeviceNum entries, as allocated in nnMain.
// Returns ErrorCode_Success only if the serial number equals targetSerialNumber. In that case
// both handles are stored in *pOutDeviceInfo and the caller must release them. On every other
// return path both handles are released here.
ErrorCode IsTargetDevice(UsbDeviceInfo* pOutDeviceInfo,
                         HANDLE* deviceHandle,
                         ::std::string targetSerialNumber,
                         ::nn::settings::factory::SerialNumber* safeModeDevice,
                         int* deviceNum) NN_NOEXCEPT
{
    NN_SDK_REQUIRES_NOT_NULL(pOutDeviceInfo);
    NN_SDK_REQUIRES_NOT_NULL(deviceHandle);
    NN_SDK_REQUIRES_NOT_NULL(safeModeDevice);
    NN_SDK_REQUIRES_NOT_NULL(deviceNum);

    WINUSB_INTERFACE_HANDLE winUsbHandle;
    if (!WinUsb_Initialize(*deviceHandle, &winUsbHandle))
    {
        CloseHandle(*deviceHandle);
        return ErrorCode_Failure;
    }

    // Handles are handed over to *pOutDeviceInfo only on a match; otherwise release them here.
    // The WinUSB handle must be freed before the underlying file handle is closed.
    bool transferOwnership = false;
    NN_UTIL_SCOPE_EXIT
    {
        if (!transferOwnership)
        {
            WinUsb_Free(winUsbHandle);
            CloseHandle(*deviceHandle);
        }
    };

    if (!IsSafeModeDevice(winUsbHandle))
    {
        return ErrorCode_Failure;
    }

    uint8_t pipeBulkIn = 0;
    uint8_t pipeBulkOut = 0;

    USB_INTERFACE_DESCRIPTOR interfaceDescriptor;
    if (!WinUsb_QueryInterfaceSettings(winUsbHandle, 0, &interfaceDescriptor))
    {
        return ErrorCode_Failure;
    }

    // WinUsb_QueryPipe is retried, but for no longer than about TimeOut ms in total, so an
    // unresponsive device cannot hang the enumeration forever.
    const DWORD QueryPipeRetryIntervalMs = 10;
    const uint32_t MaxQueryPipeAttempts = TimeOut / QueryPipeRetryIntervalMs;

    for (int i = 0; i < interfaceDescriptor.bNumEndpoints; i++)
    {
        WINUSB_PIPE_INFORMATION pipeInfo;
        uint32_t attempts = 1;
        while (!WinUsb_QueryPipe(winUsbHandle, 0, static_cast<uint8_t>(i), &pipeInfo))
        {
            if (attempts >= MaxQueryPipeAttempts)
            {
                return ErrorCode_Failure;
            }
            ++attempts;
            Sleep(QueryPipeRetryIntervalMs);
        }

        if (pipeInfo.PipeType == UsbdPipeTypeBulk && USB_ENDPOINT_DIRECTION_IN(pipeInfo.PipeId))
        {
            pipeBulkIn = pipeInfo.PipeId;
        }
        else if (pipeInfo.PipeType == UsbdPipeTypeBulk && USB_ENDPOINT_DIRECTION_OUT(pipeInfo.PipeId))
        {
            pipeBulkOut = pipeInfo.PipeId;
        }

        uint32_t timeOut = TimeOut;
        WinUsb_SetPipePolicy(winUsbHandle,
            pipeInfo.PipeId,
            PIPE_TRANSFER_TIMEOUT,
            sizeof(uint32_t),
            &timeOut);
    }

    UsbDeviceInfo info;
    info.interfaceIndex = interfaceDescriptor.bInterfaceNumber;
    info.deviceHandle = *deviceHandle;
    info.winUsbHandle = winUsbHandle;
    info.pipeBulkIn = pipeBulkIn;
    info.pipeBulkOut = pipeBulkOut;

    ::nn::settings::factory::SerialNumber serialNumber;
    if (!GetTargetSerialNumber(&info, serialNumber))
    {
        return ErrorCode_Failure;
    }

    // Record every safe-mode device for the GetSerialNumber listing, without overrunning the table.
    if (*deviceNum < MaxDeviceNum)
    {
        safeModeDevice[*deviceNum] = serialNumber;
        (*deviceNum)++;
    }

    // Same serial number as the requested target.
    if (serialNumber.string == targetSerialNumber)
    {
        *pOutDeviceInfo = info;
        transferOwnership = true;
        return ErrorCode_Success;
    }

    return ErrorCode_Failure;
}

// Enumerates present devices that expose the TargetGuid device interface and probes each one
// with IsTargetDevice. IsTargetDevice records every safe-mode device's serial number in
// safeModeDevices / *deviceNum.
// Returns ErrorCode_Success if a device whose serial number equals targetSerialNumber was found.
// Its handles are then stored in *pOutDeviceInfo and owned by the caller. Any SetupAPI or
// CreateFile error aborts the enumeration with ErrorCode_Failure.
ErrorCode EnumerateDevice(UsbDeviceInfo* pOutDeviceInfo,
                          ::std::string targetSerialNumber,
                          ::nn::settings::factory::SerialNumber* safeModeDevices,
                          int* deviceNum) NN_NOEXCEPT
{
    NN_SDK_REQUIRES_NOT_NULL(deviceNum);
    NN_SDK_REQUIRES_NOT_NULL(pOutDeviceInfo);
    bool found = false;

    HDEVINFO deviceInfoHandle = SetupDiGetClassDevs(&TargetGuid,
                                                    NULL,
                                                    NULL,
                                                    DIGCF_PRESENT | DIGCF_DEVICEINTERFACE);

    // SetupDiGetClassDevs signals failure with INVALID_HANDLE_VALUE, not NULL.
    if (deviceInfoHandle == INVALID_HANDLE_VALUE)
    {
        return ErrorCode_Failure;
    }
    // The device information set must be destroyed on every return path.
    NN_UTIL_SCOPE_EXIT{ SetupDiDestroyDeviceInfoList(deviceInfoHandle); };

    SP_DEVICE_INTERFACE_DATA interfaceData = { 0 };
    interfaceData.cbSize = sizeof(SP_DEVICE_INTERFACE_DATA);

    DWORD dwIndex = 0;

    while (NN_STATIC_CONDITION(true))
    {
        if (!SetupDiEnumDeviceInterfaces(deviceInfoHandle,
                                         NULL,
                                         &TargetGuid,
                                         dwIndex,
                                         &interfaceData))
        {
            // ERROR_NO_MORE_ITEMS marks the normal end of the enumeration.
            if (GetLastError() == ERROR_NO_MORE_ITEMS)
            {
                break;
            }
            return ErrorCode_Failure;
        }

        // First call only asks for the required size of the detail structure.
        ULONG interfaceDetailDataLength = 256;
        if (!SetupDiGetDeviceInterfaceDetail(deviceInfoHandle,
            &interfaceData,
            NULL,
            0,
            &interfaceDetailDataLength,
            NULL))
        {
            if (GetLastError() != ERROR_INSUFFICIENT_BUFFER)
            {
                return ErrorCode_Failure;
            }
        }

        ::std::unique_ptr<char[]> buffer(new char[interfaceDetailDataLength]);
        auto pInterfaceDetailData = reinterpret_cast<PSP_DEVICE_INTERFACE_DETAIL_DATA>(buffer.get());
        pInterfaceDetailData->cbSize = sizeof(SP_DEVICE_INTERFACE_DETAIL_DATA);

        if (!SetupDiGetDeviceInterfaceDetail(deviceInfoHandle,
            &interfaceData,
            pInterfaceDetailData,
            interfaceDetailDataLength,
            &interfaceDetailDataLength,
            NULL))
        {
            return ErrorCode_Failure;
        }

        // WinUSB requires the device handle to be opened with FILE_FLAG_OVERLAPPED.
        HANDLE deviceHandle = CreateFile(pInterfaceDetailData->DevicePath,
            GENERIC_WRITE | GENERIC_READ,
            FILE_SHARE_WRITE | FILE_SHARE_READ,
            NULL,
            OPEN_EXISTING,
            FILE_ATTRIBUTE_NORMAL | FILE_FLAG_OVERLAPPED,
            NULL);

        if (deviceHandle == INVALID_HANDLE_VALUE)
        {
            return ErrorCode_Failure;
        }

        // IsTargetDevice takes ownership of deviceHandle (it closes it unless it is the target).
        if (IsTargetDevice(pOutDeviceInfo, &deviceHandle, targetSerialNumber, safeModeDevices, deviceNum) == ErrorCode_Success)
        {
            found = true;
        }

        dwIndex++;
    }

    return found ? ErrorCode_Success : ErrorCode_Failure;
}

// Allocation hook registered with nn::fs::SetAllocator; forwards the request to std::malloc.
void* AllocaterForFileSystem(size_t size) NN_NOEXCEPT
{
    void* const memory = ::std::malloc(size);
    return memory;
}

// Deallocation hook registered with nn::fs::SetAllocator; releases memory obtained from
// AllocaterForFileSystem. The size argument is not needed by std::free and is ignored.
void DeallocaterForFileSystem(void* addr, size_t /* size */) NN_NOEXCEPT
{
    ::std::free(addr);
}

// Handles "GetConfigurationId1 <serial>": sends command "0" and logs the configuration id the
// device returns.
ErrorCode DoRequestConfigurationId1Command(UsbDeviceInfo* pDeviceInfo, ::std::vector<::std::string>& args) NN_NOEXCEPT
{
    NN_SDK_REQUIRES_NOT_NULL(pDeviceInfo);

    if (CheckArgumentsCount(args, 3) != ErrorCode_Success)
    {
        return ErrorCode_Failure;
    }

    if (SendCommand(pDeviceInfo, "0") != ErrorCode_Success)
    {
        ERROR_LOG("Failed to request configurationId1\n");
        return ErrorCode_Failure;
    }

    // Zero-filled so a reply shorter than the struct leaves no indeterminate bytes.
    ::nn::settings::factory::ConfigurationId1 configurationId1 = {};
    if (ReceiveData(reinterpret_cast<::nn::Bit8 *>(&configurationId1),
        sizeof(::nn::settings::factory::ConfigurationId1),
        pDeviceInfo) != ErrorCode_Success)
    {
        ERROR_LOG("Failed to receive configurationId1\n");
        return ErrorCode_Failure;
    }

    // The reply is printed with %s; guarantee termination even if the device sent none.
    configurationId1.string[sizeof(configurationId1.string) - 1] = '\0';

    SUCCESS_LOG_F("%s\n", configurationId1.string);
    return ErrorCode_Success;
}

// Handles "GetFirmwareVersion <serial>": sends command "3" and logs the firmware display name
// the device returns (at most sizeof(FirmwareVersion::displayName) bytes).
ErrorCode DoRequestFirmwareVersionCommand(UsbDeviceInfo* pDeviceInfo, ::std::vector<::std::string>& args) NN_NOEXCEPT
{
    NN_SDK_REQUIRES_NOT_NULL(pDeviceInfo);

    if (CheckArgumentsCount(args, 3) != ErrorCode_Success)
    {
        return ErrorCode_Failure;
    }

    if (SendCommand(pDeviceInfo, "3") != ErrorCode_Success)
    {
        ERROR_LOG("Failed to request firmware version\n");
        return ErrorCode_Failure;
    }

    // Only used to size the receive buffer to FirmwareVersion::displayName.
    ::nn::settings::system::FirmwareVersion firmwareVersion;
    NN_UNUSED(firmwareVersion);

    const int size = sizeof(firmwareVersion.displayName);
    // One extra byte beyond what ReceiveData may fill, and zero-initialized, so the text is
    // always NUL-terminated before being printed with %s.
    ::nn::Bit8 buffer[size + 1] = {};

    if (ReceiveData(buffer, size, pDeviceInfo) != ErrorCode_Success)
    {
        ERROR_LOG("Failed to receive firmware version\n");
        return ErrorCode_Failure;
    }

    SUCCESS_LOG_F("%s\n", reinterpret_cast<char *>(buffer));
    return ErrorCode_Success;
}

// Handles "SendImage <serial> <path>".
// 1. Reads the recovery image at <path>, relative to the current working directory, through a
//    temporary "host" mount.
// 2. Logs a small debug dump of the image.
// 3. Sends command "2", then a DsHeader carrying the image size, then the image bytes.
ErrorCode DoSendImageCommand(UsbDeviceInfo* pDeviceInfo, ::std::vector<::std::string>& args) NN_NOEXCEPT
{
    NN_SDK_REQUIRES_NOT_NULL(pDeviceInfo);

    if (CheckArgumentsCount(args, 4) != ErrorCode_Success)
    {
        return ErrorCode_Failure;
    }
    const ::std::string& relativePath = args[3];

    // GetCurrentDirectoryW always writes wide characters, whatever TCHAR is configured as.
    ::std::unique_ptr<WCHAR[]> wpath(new WCHAR[MAX_PATH]);
    const DWORD wpathLength = ::GetCurrentDirectoryW(MAX_PATH, wpath.get());
    if (wpathLength == 0 || wpathLength >= MAX_PATH)
    {
        ERROR_LOG("Failed to get the current directory\n");
        return ErrorCode_Failure;
    }

    // One UTF-16 code unit can take up to 3 UTF-8 bytes, so MAX_PATH bytes may not be enough.
    const int PathBufferSize = MAX_PATH * 3;
    ::std::unique_ptr<char[]> path(new char[PathBufferSize]);
    if (::WideCharToMultiByte(CP_UTF8, 0, wpath.get(), -1, path.get(), PathBufferSize, NULL, NULL) == 0)
    {
        ERROR_LOG("Failed to convert the current directory\n");
        return ErrorCode_Failure;
    }

    const ::std::string MountName = "host";

    if (::nn::fs::MountHost(MountName.c_str(), path.get()).IsFailure())
    {
        ERROR_LOG_F("Failed to mount the directory: %s\n", path.get());
        return ErrorCode_Failure;
    }
    NN_UTIL_SCOPE_EXIT{ ::nn::fs::Unmount(MountName.c_str()); };

    const std::string imagePath = MountName + ":/" + relativePath;

    ::nn::fs::FileHandle fileHandle;
    if (::nn::fs::OpenFile(&fileHandle,
        imagePath.c_str(),
        ::nn::fs::OpenMode_Read).IsFailure())
    {
        ERROR_LOG_F("Failed to open the file: %s%c%s\n", path.get(),'\\',relativePath.c_str());
        return ErrorCode_Failure;
    }

    NN_UTIL_SCOPE_EXIT{ ::nn::fs::CloseFile(fileHandle); };

    int64_t fileSize = 0;
    if (::nn::fs::GetFileSize(&fileSize, fileHandle).IsFailure())
    {
        ERROR_LOG("Failed to get the file size\n");
        return ErrorCode_Failure;
    }

    // The size travels in DsHeader as an int (MakeDsHeader), so it must be positive and fit.
    if (fileSize <= 0 || fileSize > INT_MAX)
    {
        ERROR_LOG_F("Invalid image size: %lld\n", static_cast<long long>(fileSize));
        return ErrorCode_Failure;
    }

    const int totalSize = static_cast<int>(fileSize);
    ::std::unique_ptr<char[]> imageStorage(new char[totalSize]);

    if (::nn::fs::ReadFile(fileHandle, 0, imageStorage.get(), static_cast<size_t>(totalSize)).IsFailure())
    {
        ERROR_LOG("Failed to read the file\n");
        return ErrorCode_Failure;
    }

    // Debug output for checking the recovery image: eight 32-bit words starting at offset
    // 0x10000, logged only when the image is large enough to contain them.
    NN_SDK_LOG("ImageSize = %lld\n", static_cast<long long>(fileSize));
    const int DumpOffset = 0x10000;
    const int DumpWordCount = 8;
    if (totalSize >= DumpOffset + DumpWordCount * static_cast<int>(sizeof(unsigned int)))
    {
        for (int i = 0; i < DumpWordCount; ++i)
        {
            unsigned int word;
            // memcpy avoids aliasing/alignment issues of reinterpret_cast on a char buffer.
            ::std::memcpy(&word, imageStorage.get() + DumpOffset + i * sizeof(unsigned int), sizeof(word));
            NN_SDK_LOG("%x\n", word);
        }
    }

    if (SendCommand(pDeviceInfo, "2") != ErrorCode_Success)
    {
        ERROR_LOG("Failed to request recovery\n");
        return ErrorCode_Failure;
    }

    ULONG bytesTransferred = 0;
    DsHeader header;
    MakeDsHeader(&header, totalSize);

    // Send the total size of the recovery image first.
    if (SendData(&bytesTransferred, pDeviceInfo, &header, sizeof(header)) != ErrorCode_Success)
    {
        ERROR_LOG("Failed to send Image\n");
        return ErrorCode_Failure;
    }

    if (SendData(&bytesTransferred, pDeviceInfo, imageStorage.get(), static_cast<uint32_t>(totalSize)) != ErrorCode_Success)
    {
        ERROR_LOG("Failed to send Image\n");
        return ErrorCode_Failure;
    }

    SUCCESS_LOG("Send Image Finished\n");
    return ErrorCode_Success;
}

// Sorts the first deviceNum serial numbers in ascending order (see Compare), then logs the device
// count followed by one serial number per line.
void PrintSerialNumber(::std::unique_ptr<::nn::settings::factory::SerialNumber[]>& devices, int deviceNum) NN_NOEXCEPT
{
    ::nn::settings::factory::SerialNumber* const first = devices.get();
    ::nn::settings::factory::SerialNumber* const last = first + deviceNum;
    ::std::sort(first, last, Compare);

    SUCCESS_LOG_F("%d device(s) found\n", deviceNum);
    for (const ::nn::settings::factory::SerialNumber* current = first; current < last; ++current)
    {
        NN_LOG("%s\n", current->string);
    }
}

// Handles "GetBatteryCharge <serial>": sends command "4" and logs the text reply the device
// returns (up to 256 bytes).
ErrorCode DoRequestBatteryChargeCommand(UsbDeviceInfo* pDeviceInfo, ::std::vector<::std::string>& args) NN_NOEXCEPT
{
    NN_SDK_REQUIRES_NOT_NULL(pDeviceInfo);

    if (CheckArgumentsCount(args, 3) != ErrorCode_Success)
    {
        return ErrorCode_Failure;
    }

    if (SendCommand(pDeviceInfo, "4") != ErrorCode_Success)
    {
        ERROR_LOG("Failed to request battery charge\n");
        return ErrorCode_Failure;
    }

    const int bufferSize = 256;
    // One extra zero byte beyond what ReceiveData may fill guarantees NUL termination for %s.
    ::nn::Bit8 buffer[bufferSize + 1] = {};

    if (ReceiveData(buffer, bufferSize, pDeviceInfo) != ErrorCode_Success)
    {
        ERROR_LOG("Failed to receive battery charge\n");
        return ErrorCode_Failure;
    }

    SUCCESS_LOG_F("%s\n", reinterpret_cast<char *>(buffer));
    return ErrorCode_Success;
}

// Command dispatch table searched by nnMain. The first host argument selects the entry whose
// name matches exactly, and that entry's handler runs against the selected device.
// "GetSerialNumber" is handled directly in nnMain and therefore does not appear here.
const Command g_Commands[] =
{
    { "GetConfigurationId1", &DoRequestConfigurationId1Command },
    { "SendImage",           &DoSendImageCommand },
    { "GetFirmwareVersion",  &DoRequestFirmwareVersionCommand },
    { "GetBatteryCharge",    &DoRequestBatteryChargeCommand },
};

// Program entry point. Usage: <program> <command> [<serial number> [<command arguments>...]]
// Enumerates attached safe-mode devices and selects the one whose serial number equals the
// second argument (if given). "GetSerialNumber" lists all safe-mode serial numbers found; any
// other command is dispatched through g_Commands against the selected device.
extern "C" int nnMain()
{
    ::nn::fs::SetAllocator(AllocaterForFileSystem, DeallocaterForFileSystem);

    ::std::vector<::std::string> args(::nn::os::GetHostArgc());
    for (int i = 0; i < static_cast<int>(args.size()); i++)
    {
        args[i] = ::std::string(::nn::os::GetHostArgv()[i]);
    }

    if (args.size() < 2)
    {
        ERROR_LOG("No Specified Command.\n");
        return ExitCode_Failure;
    }

    const ::std::string commandName = args[1];
    ::std::string targetSerialNumber;
    if (args.size() >= 3)
    {
        targetSerialNumber = args[2];
    }

    int deviceNum = 0;
    // Value-initialized so no command or cleanup ever sees indeterminate handles.
    UsbDeviceInfo info = {};
    ::std::unique_ptr<::nn::settings::factory::SerialNumber[]> devices(new ::nn::settings::factory::SerialNumber[MaxDeviceNum]);
    const bool found = (EnumerateDevice(&info, targetSerialNumber, devices.get(), &deviceNum) == ErrorCode_Success);
    if (!found && targetSerialNumber.size() != 0)
    {
        ERROR_LOG("Not found\n");
        return ExitCode_Failure;
    }

    // info owns open handles only when the target was found. WinUsb_Free must run before the
    // underlying file handle is closed.
    NN_UTIL_SCOPE_EXIT
    {
        if (found)
        {
            WinUsb_Free(info.winUsbHandle);
            CloseHandle(info.deviceHandle);
        }
    };

    if (commandName == "GetSerialNumber")
    {
        PrintSerialNumber(devices, deviceNum);

        return ExitCode_Success;
    }

    for (const Command& command : g_Commands)
    {
        if (commandName == command.name)
        {
            return (command.function(&info, args) == ErrorCode_Success) ? ExitCode_Success : ExitCode_Failure;
        }
    }

    ERROR_LOG_F("Invalid command: %s\n", args[1].c_str());
    return ExitCode_Failure;
}

