﻿/*--------------------------------------------------------------------------------*
  Copyright (C)Nintendo All rights reserved.

  These coded instructions, statements, and computer programs contain proprietary
  information of Nintendo and/or its licensed developers and are protected by
  national and international copyright laws. They may not be disclosed to third
  parties or copied or duplicated in any form, in whole or in part, without the
  prior written consent of Nintendo.

  The content herein is highly confidential and should be handled accordingly.
 *--------------------------------------------------------------------------------*/

#include "DevMenuCommand_NetworkUtil.h"

#include <nn/nn_Common.h>
#include <nn/nn_Assert.h>
#include <nn/nn_Macro.h>
#include <nn/nn_Log.h>
#include <nn/os.h>
#include <nn/socket.h>
#include <nn/util/util_FormatString.h>
#include <nn/util/util_ScopeExit.h>
#include <nn/util/util_StringUtil.h>

// Uncomment the following line to enable verbose diagnostic logging in this file.
// #define DETAIL_LOG_ENABLED

#if defined (DETAIL_LOG_ENABLED)

// Verbose logging enabled: forwards the printf-style arguments to NN_LOG.
#define DETAIL_LOG(...) NN_LOG(__VA_ARGS__)

#else

// Verbose logging disabled: expands to nothing, so the arguments are not evaluated.
// Locals referenced only from DETAIL_LOG calls may therefore need NN_UNUSED().
#define DETAIL_LOG(...)

#endif

namespace devmenuUtil { namespace network {

namespace
{
    // IPv4 header, mirroring the layout of the Linux `struct iphdr` (RFC 791).
    // Multi-byte fields hold the raw network-byte-order values as received.
    struct IpHeader
    {
        // Version and IHL share the first octet. The bit-field base type is uint8_t
        // (not unsigned int) so that compilers which do not pack bit-fields of
        // different base types into one storage unit still place `tos` at offset 1.
#if defined (NN_BUILD_TARGET_PLATFORM_ENDIAN_LITTLE)
        uint8_t ihl:4;      // Header length in 32-bit words.
        uint8_t version:4;  // IP version.
#else
        uint8_t version:4;  // IP version.
        uint8_t ihl:4;      // Header length in 32-bit words.
#endif
        uint8_t tos;
        uint16_t tot_len;
        uint16_t id;
        uint16_t frag_off;
        uint8_t ttl;
        uint8_t protocol;
        uint16_t check;
        uint32_t saddr;
        uint32_t daddr;
    };
    static_assert(sizeof (IpHeader) == 20, "IpHeader must match the 20-byte IPv4 header without options");

    // ICMP header, mirroring the layout of the Linux `struct icmphdr` (RFC 792).
    struct IcmpHeader
    {
        uint8_t type;
        uint8_t code;
        uint16_t checksum;
        union
        {
            // Echo / Echo Reply.
            struct
            {
                uint16_t id;
                uint16_t sequence;
            } echo;
            uint32_t gateway;
            // Destination Unreachable / Fragmentation Needed.
            struct
            {
                // Not named `__unused`: identifiers with a double underscore are reserved,
                // and some C libraries define `__unused` as an attribute macro.
                uint16_t unused;
                uint16_t mtu;
            } frag;
            uint8_t reserved[4];
        } un;
    };
    static_assert(sizeof (IcmpHeader) == 8, "IcmpHeader must match the 8-byte ICMP header");

    // ICMP message type values (RFC 792).
    const int IcmpEchoReply  = 0;
    const int IcmpEcho       = 8;
    const int IcmpTimeExceed = 11;
}

namespace
{
    // Returns a value of type T whose bytes are filled by nn::os::GenerateRandomBytes().
    // Intended for trivially copyable types such as integers.
    template <typename T>
    T GenerateRandomNumber() NN_NOEXCEPT
    {
        T value;
        nn::os::GenerateRandomBytes(&value, sizeof (value));

        return value;
    }

    // Computes the Internet checksum (RFC 1071) over `size` bytes starting at `pData`.
    //
    // 16-bit words are summed in native byte order. The one's complement sum is
    // byte-order independent, so the result can be stored directly into a header
    // field without swapping. `pData` must be 2-byte aligned.
    nn::Bit16 CalculateNetworkChecksum(const void* pData, size_t size) NN_NOEXCEPT
    {
        NN_ASSERT_EQUAL((reinterpret_cast<uintptr_t>(pData) % 2), 0u);

        uint64_t checksum = 0;

        const nn::Bit16* p = reinterpret_cast<const nn::Bit16*>(pData);

        // Compute the one's complement sum.
        while (size >= 2)
        {
            checksum += *p++;
            size -= 2;
        }
        if (size == 1)
        {
            // An odd trailing byte counts as a word padded with a zero byte after it.
            // Build that word in memory so the padding lands on the correct side
            // regardless of the host byte order.
            nn::Bit16 lastWord = 0;
            *reinterpret_cast<nn::Bit8*>(&lastWord) = *reinterpret_cast<const nn::Bit8*>(p);
            checksum += lastWord;
        }
        // Fold the carries back into the low 16 bits.
        while (checksum >= 0x10000)
        {
            checksum = (checksum & 0xFFFF) + (checksum >> 16);
        }

        // Finally, return the one's complement of the sum.
        return static_cast<nn::Bit16>(~checksum);
    }

    // Waits up to `timeout` for `sock` to become readable, then receives once into pBuffer.
    //
    // On success, returns true and stores the received byte count in *pOutReceived.
    // On failure, returns false and sets *pOutIsTimedOut:
    //   - true if Select() timed out;
    //   - false if Select()/Recv() failed, or Recv() returned no data.
    bool Receive(bool* pOutIsTimedOut, ssize_t* pOutReceived, int sock, void* pBuffer, size_t size, nn::TimeSpan timeout) NN_NOEXCEPT
    {
        *pOutIsTimedOut = false;

        nn::socket::FdSet readFds;

        nn::socket::FdSetZero(&readFds);
        nn::socket::FdSetSet(sock, &readFds);

        // Split the timeout into whole seconds plus the remaining microseconds.
        nn::socket::TimeVal selectTimeout = {};

        selectTimeout.tv_sec = static_cast<long>(timeout.GetSeconds());

        timeout -= nn::TimeSpan::FromSeconds(timeout.GetSeconds());
        selectTimeout.tv_usec = static_cast<long>(timeout.GetMicroSeconds());

        // The first argument is the highest-numbered descriptor plus one.
        // A fixed value of 1 would only cover descriptor 0.
        int result = nn::socket::Select(sock + 1, &readFds, nullptr, nullptr, &selectTimeout);

        if (result < 0)
        {
            DETAIL_LOG("[ERROR] nn::socket::Select() failed. error = %d\n", nn::socket::GetLastError());
            return false;
        }
        // Timed out.
        if (result == 0)
        {
            *pOutIsTimedOut = true;
            return false;
        }

        ssize_t received = nn::socket::Recv(sock, pBuffer, size, nn::socket::MsgFlag::Msg_None);

        if (received < 1)
        {
            DETAIL_LOG("[ERROR] nn::socket::Recv() failed. error = %d\n", nn::socket::GetLastError());
            return false;
        }

        *pOutReceived = received;

        return true;
    }
}

bool GetIpAddress(nn::socket::InAddr* pOut, const char* pName) NN_NOEXCEPT
{
    NN_ASSERT_NOT_NULL(pOut);
    NN_ASSERT_NOT_NULL(pName);

    // Fast path: pName is already a numeric IPv4 address string.
    const bool isNumeric = (nn::socket::InetAton(pName, pOut) == 1);
    if (isNumeric)
    {
        return true;
    }

    // Otherwise resolve the name, asking only for IPv4 raw-socket results.
    nn::socket::AddrInfo request = {};
    request.ai_family = nn::socket::Family::Af_Inet;
    request.ai_socktype = nn::socket::Type::Sock_Raw;

    nn::socket::AddrInfo* pList = nullptr;
    const nn::socket::AiErrno status = nn::socket::GetAddrInfo(pName, nullptr, &request, &pList);

    // Release the result list on every exit path.
    NN_UTIL_SCOPE_EXIT
    {
        if (pList != nullptr)
        {
            nn::socket::FreeAddrInfo(pList);
        }
    };

    if (status == nn::socket::AiErrno::EAi_Success)
    {
        // Use the first result's address.
        const nn::socket::SockAddrIn* pInet = reinterpret_cast<const nn::socket::SockAddrIn*>(pList->ai_addr);
        *pOut = pInet->sin_addr;
        return true;
    }

    DETAIL_LOG("[ERROR] nn::socket::GetAddrInfo(%s) failed. error = %s\n",
        pName, nn::socket::GAIStrError(status));
    return false;
}

bool GetHostName(HostName* pOut, const nn::socket::InAddr& addr) NN_NOEXCEPT
{
    NN_ASSERT_NOT_NULL(pOut);

    // Reverse lookup of an IPv4 address.
    const nn::socket::HostEnt* const pEntry =
        nn::socket::GetHostEntByAddr(&addr, sizeof (addr), nn::socket::Family::Af_Inet);

    if (pEntry != nullptr)
    {
        // Copy the host name, bounded by the size of the output buffer.
        nn::util::Strlcpy(pOut->value, pEntry->h_name, sizeof (pOut->value));
        return true;
    }

    DETAIL_LOG("[ERROR] nn::socket::GetHostEntByAddr() failed. error = %s\n",
        nn::socket::HStrError(*nn::socket::GetHError()));
    return false;
}

// Sends one ICMP echo request to targetAddr with the IP TTL set to numHops.
// Then waits up to receiveTimeout for a single ICMP message and classifies it.
//
// On PingResult_Success:
//   - *pOutIsReached is true for an echo reply carrying our id/sequence, or false for
//     a Time Exceeded message (the TTL ran out before reaching the target).
//   - *pOutReplySourceAddr holds the source address of the received message.
// *pOutReplySourceAddr is also written when PingResult_UnhandledIcmpType is returned.
PingResult Ping(bool* pOutIsReached, nn::socket::InAddr* pOutReplySourceAddr,
    const nn::socket::InAddr& targetAddr, int32_t numHops, nn::TimeSpan receiveTimeout) NN_NOEXCEPT
{
    NN_ASSERT_NOT_NULL(pOutIsReached);
    NN_ASSERT_NOT_NULL(pOutReplySourceAddr);
    NN_ASSERT_GREATER(numHops, 0);

    // IPv4 "protocol" field value for ICMP.
    const int IpProtocolIcmp = 1;
    // IPv4 version number.
    const int IpVersion4 = 4;
    // Minimum IHL (header length in 32-bit words), i.e. a header without options.
    const int MinIpHeaderWords = static_cast<int>(sizeof (IpHeader) / 4);

    int sock = nn::socket::Socket(nn::socket::Family::Af_Inet,
        nn::socket::Type::Sock_Raw, nn::socket::Protocol::IpProto_Icmp);

    if (sock < 0)
    {
        DETAIL_LOG("[ERROR] nn::socket::Socket() failed. error = %d\n", nn::socket::GetLastError());
        return PingResult_InitializationError;
    }

    NN_UTIL_SCOPE_EXIT
    {
        nn::socket::Close(sock);
    };

    // Limit the TTL so that a router where it expires can answer with Time Exceeded
    // (handled below as "not reached").
    if (nn::socket::SetSockOpt(sock, nn::socket::Level::Sol_Ip, nn::socket::Option::Ip_Ttl, &numHops, sizeof (numHops)) != 0)
    {
        DETAIL_LOG("[ERROR] nn::socket::SetSockOpt(Ip_Ttl) failed. error = %d\n", nn::socket::GetLastError());
        return PingResult_InitializationError;
    }

    nn::socket::SockAddrIn sockAddr = {};

    sockAddr.sin_family = nn::socket::Family::Af_Inet;
    sockAddr.sin_addr = targetAddr;

    // Build the echo request. The checksum field is still zero while the checksum is computed.
    IcmpHeader header = {};

    header.type = static_cast<uint8_t>(IcmpEcho);
    header.code = 0;
    // Random id/sequence identify our reply. They are compared as raw bytes later,
    // so no byte-order conversion is needed.
    header.un.echo.id = GenerateRandomNumber<uint16_t>();
    header.un.echo.sequence = GenerateRandomNumber<uint16_t>();

    header.checksum = CalculateNetworkChecksum(&header, sizeof (header));

    ssize_t sent = nn::socket::SendTo(sock, &header, sizeof (header), nn::socket::MsgFlag::Msg_None,
        reinterpret_cast<nn::socket::SockAddr*>(&sockAddr), sizeof (sockAddr));

    if (sent < 1)
    {
        DETAIL_LOG("[ERROR] nn::socket::SendTo() failed. error = %d\n", nn::socket::GetLastError());
        return PingResult_SendError;
    }

    NN_ALIGNAS(8) nn::Bit8 buffer[1024];
    bool isTimedOut = false;
    ssize_t received = 0;

    if (!Receive(&isTimedOut, &received, sock, buffer, sizeof (buffer), receiveTimeout))
    {
        return isTimedOut ? PingResult_TimedOut : PingResult_ReceiveError;
    }

    // Receive() only succeeds with at least one byte, so this conversion is safe.
    const size_t receivedSize = static_cast<size_t>(received);

    // The received data is parsed as an IPv4 header followed by an ICMP message.
    if (receivedSize < sizeof (IpHeader) + sizeof (IcmpHeader))
    {
        return PingResult_InvalidResponseSize;
    }

    const IpHeader* pIpHeader = reinterpret_cast<const IpHeader*>(buffer);

    DETAIL_LOG("IpHeader::version    = %u\n",     pIpHeader->version);
    DETAIL_LOG("IpHeader::ihl        = %u\n",     pIpHeader->ihl);
    DETAIL_LOG("IpHeader::tos        = %u\n",     pIpHeader->tos);
    DETAIL_LOG("IpHeader::tot_len    = %u\n",     pIpHeader->tot_len);
    DETAIL_LOG("IpHeader::id         = %u\n",     pIpHeader->id);
    DETAIL_LOG("IpHeader::frag_off   = %u\n",     pIpHeader->frag_off);
    DETAIL_LOG("IpHeader::ttl        = %u\n",     pIpHeader->ttl);
    DETAIL_LOG("IpHeader::protocol   = %u\n",     pIpHeader->protocol);
    DETAIL_LOG("IpHeader::check      = 0x%04X\n", pIpHeader->check);

    nn::socket::InAddr sourceAddr = {};
    sourceAddr.S_addr = pIpHeader->saddr;

    nn::socket::InAddr destinationAddr = {};
    destinationAddr.S_addr = pIpHeader->daddr;

    DETAIL_LOG("IpHeader::saddr      = %u (%s)\n", pIpHeader->saddr, nn::socket::InetNtoa(sourceAddr));
    DETAIL_LOG("IpHeader::daddr      = %u (%s)\n", pIpHeader->daddr, nn::socket::InetNtoa(destinationAddr));
    NN_UNUSED(destinationAddr);

    // Accept IPv4 headers that carry options (IHL > 5) as well as the plain 20-byte form.
    if (!(pIpHeader->version == IpVersion4 && pIpHeader->ihl >= MinIpHeaderWords && pIpHeader->protocol == IpProtocolIcmp))
    {
        return PingResult_InvalidIpHeader;
    }

    // The ICMP message starts right after the full IP header, including any options.
    const size_t ipHeaderSize = static_cast<size_t>(pIpHeader->ihl) * 4;

    if (receivedSize < ipHeaderSize + sizeof (IcmpHeader))
    {
        return PingResult_InvalidResponseSize;
    }

    const IcmpHeader* pIcmpHeader = reinterpret_cast<const IcmpHeader*>(&buffer[ipHeaderSize]);

    DETAIL_LOG("IcmpHeader::type     = %u\n",     pIcmpHeader->type);
    DETAIL_LOG("IcmpHeader::code     = %u\n",     pIcmpHeader->code);
    DETAIL_LOG("IcmpHeader::checksum = 0x%04X\n", pIcmpHeader->checksum);
    DETAIL_LOG("IcmpHeader::id       = %u\n",     pIcmpHeader->un.echo.id);
    DETAIL_LOG("IcmpHeader::sequence = %u\n",     pIcmpHeader->un.echo.sequence);

    // An echo reply must echo back our id and sequence number.
    if (pIcmpHeader->type == IcmpEchoReply &&
        !(pIcmpHeader->un.echo.id == header.un.echo.id && pIcmpHeader->un.echo.sequence == header.un.echo.sequence))
    {
        return PingResult_InvalidIcmpHeader;
    }

    *pOutReplySourceAddr = sourceAddr;

    switch (pIcmpHeader->type)
    {
    case IcmpEchoReply:
        *pOutIsReached = true;
        return PingResult_Success;
    case IcmpTimeExceed:
        *pOutIsReached = false;
        return PingResult_Success;
    default:
        return PingResult_UnhandledIcmpType;
    }
} // NOLINT(impl/function_size)

const char* GetPingResultString(PingResult result) NN_NOEXCEPT
{
    // Map each PingResult to its printable name; unrecognized values give "(unknown)".
    const char* pName = "(unknown)";

    switch (result)
    {
    case PingResult_Success:             pName = "Success";             break;
    case PingResult_TimedOut:            pName = "TimedOut";            break;
    case PingResult_InitializationError: pName = "InitializationError"; break;
    case PingResult_SendError:           pName = "SendError";           break;
    case PingResult_ReceiveError:        pName = "ReceiveError";        break;
    case PingResult_InvalidResponseSize: pName = "InvalidResponseSize"; break;
    case PingResult_InvalidIpHeader:     pName = "InvalidIpHeader";     break;
    case PingResult_InvalidIcmpHeader:   pName = "InvalidIcmpHeader";   break;
    case PingResult_UnhandledIcmpType:   pName = "UnhandledIcmpType";   break;
    default:
        break;
    }

    return pName;
}

}}
