﻿/*--------------------------------------------------------------------------------*
  Copyright (C)Nintendo All rights reserved.

  These coded instructions, statements, and computer programs contain proprietary
  information of Nintendo and/or its licensed developers and are protected by
  national and international copyright laws. They may not be disclosed to third
  parties or copied or duplicated in any form, in whole or in part, without the
  prior written consent of Nintendo.

  The content herein is highly confidential and should be handled accordingly.
 *--------------------------------------------------------------------------------*/

#include "Network.h"
#include <nn/nn_Log.h>
#include <nn/os/os_Tick.h>

#if defined(__INTELLISENSE__) && __INTELLISENSE__  // intellisence は clang の #include_next <poll.h> を処理してくれない
#include <poll.h>
#endif
#include <nn/socket/socket_Api.h>

#include <cstring>
#include <memory>
#include <vector>


// ICMP 関連の定義が欲しい
namespace { namespace posix {
#include <netinet/ip_icmp.h>
} }


namespace ApConnectivityTest
{

// Creates a controller that is not attached to any ping session.
// Its weak reference is empty, so Cancel() finds nothing to lock and does nothing.
Network::PingController::PingController() :
        m_Socket()
{
}


// キャンセルオブジェクトコンストラクタ
Network::PingController::PingController(const std::shared_ptr<int>& socket) :
        m_Socket(socket)
{
}


// Cancels the ping session this controller is attached to.
// Stores the "cancelled" marker (-3) in the shared socket slot, so PingTask::Run()
// can tell a cancel-induced socket failure from a real one. It then closes the
// socket if one is actually open. It does nothing once every PingTask sharing
// the slot has been destroyed.
void Network::PingController::Cancel()
{
    auto socket = m_Socket.lock();
    if (!socket)
    {
        return;
    }

    const int descriptor = *socket;
    *socket = -3;

    // Negative values are sentinels (not opened yet, already cancelled, already
    // closed), not descriptors. There is nothing to close, and calling Cancel()
    // twice must not close anything a second time.
    if (descriptor >= 0)
    {
        nn::socket::Close(descriptor);
    }
}


// Creates a ping task that sends `count` ICMP Echo Requests carrying `length`
// payload bytes to `dest`, and invokes `callback` when the session ends.
// The socket slot starts at -2, which Run() treats as "socket not opened yet".
// The slot is shared with copies of this task and with PingController.
Network::PingTask::PingTask(const std::string& dest, size_t length, size_t count, const std::function<void()>& callback) :
        Network::Task(Network::TaskType_Ping),
        m_Dest(dest),
        m_DataLength(length),
        m_TotalCount(count),
        m_Callback(callback),
        m_Socket(std::make_shared<int>(-2))
{
}


// Destructor. It releases nothing explicitly: the raw socket is closed by Run()
// or PingController::Cancel(), not here.
Network::PingTask::~PingTask() = default;


// ping 送受信処理
void Network::PingTask::Run()
{
    // ネットワーク接続が閉じている
    if (!Network::GetInstance().m_pConnection)
    {
        return;
    }

    // 初回 ping 初期化処理
    if (*m_Socket == -2)
    {
        auto host = nn::socket::GetHostEntByName(m_Dest.c_str());  // NSD が働いててもよい？
        if (!host)
        {
            auto pErrno = nn::socket::GetHError();
            NN_LOG("errno: %s\n", nn::socket::HStrError(*pErrno));
            return;
        }
        m_DestHost = host->h_name;
        m_DestIp = *reinterpret_cast<nn::socket::InAddr*>(*host->h_addr_list);

        NN_LOG("Ping send to %s(%s)\n", m_DestHost.c_str(), nn::socket::InetNtoa(m_DestIp));

        m_IcmpId = static_cast<uint16_t>(nn::os::GetSystemTick().GetInt64Value());
        m_pSendTimerEvent = std::shared_ptr<nn::os::TimerEvent>(new nn::os::TimerEvent(nn::os::EventClearMode_AutoClear));
        m_pSendTimerEvent->Signal();

        m_SendBuffer.resize(sizeof(posix::icmphdr) + m_DataLength);

        m_CurrentCount = 0;
        m_SuccessCount = 0;

        // raw ソケットを開く
        *m_Socket = nn::socket::Socket(nn::socket::Family::Af_Inet, nn::socket::Type::Sock_Raw, nn::socket::Protocol::IpProto_Icmp);
        if (m_Socket < 0)
        {
            IcmpResult result;
            result.type = IcmpResultType_CriticalError;
            m_Callback();

            return;
        }
    }

    // 送信間隔タイマー待ち
    m_pSendTimerEvent->Wait();

    // ICMP ヘッダを作成
    posix::icmphdr* sendIcmpHeader = reinterpret_cast<posix::icmphdr*>(m_SendBuffer.data());
    sendIcmpHeader->type = ICMP_ECHO;
    sendIcmpHeader->code = 0;
    sendIcmpHeader->checksum = 0;
    sendIcmpHeader->un.echo.id = m_IcmpId;
    sendIcmpHeader->un.echo.sequence = m_CurrentCount;

    // ICMP データを作成
    uint8_t* sendIcmpData = m_SendBuffer.data() + sizeof(posix::icmphdr);
    for (int i = 0; i < m_DataLength; ++i)
    {
        sendIcmpData[i] = static_cast<uint8_t>(i + sendIcmpHeader->un.echo.id);
    }

    // チェックサムを計算
    {
        uint32_t checksum = 0;

        auto size = m_SendBuffer.size();
        if (size & 1)
        {
            checksum += m_SendBuffer[--size];
        }

        size /= 2;
        while (size--)
        {
            checksum += reinterpret_cast<uint16_t*>(m_SendBuffer.data())[size];
        }

        checksum = (checksum & 0xffff) + (checksum >> 16);
        checksum = (checksum & 0xffff) + (checksum >> 16);
        sendIcmpHeader->checksum = static_cast<uint16_t>(~checksum);
    }

    // ICMP 送信先アドレスを取得
    nn::socket::AddrInfo addrInfoHint = {};
    addrInfoHint.ai_family = nn::socket::Family::Pf_Inet;
    addrInfoHint.ai_socktype = nn::socket::Type::Sock_Raw;
    addrInfoHint.ai_protocol = nn::socket::Protocol::IpProto_Icmp;
    addrInfoHint.ai_flags = nn::socket::AddrInfoFlag::Ai_Passive;
    nn::socket::AddrInfo* addrInfo;
    nn::socket::GetAddrInfo(nn::socket::InetNtoa(m_DestIp), nullptr, &addrInfoHint, &addrInfo);

    // ICMP Echo 要求送信
    auto pingSendTick = nn::os::GetSystemTick();
    m_pSendTimerEvent->StartOneShot(nn::TimeSpan::FromSeconds(1));
    if (nn::socket::SendTo(*m_Socket, m_SendBuffer.data(), m_SendBuffer.size(), nn::socket::MsgFlag::Msg_None, addrInfo->ai_addr, addrInfo->ai_addrlen) < 0)
    {
        if (*m_Socket == -3)
        {
            PrintStatistics();
            m_Callback();
        }
        else
        {
            PrintSocketError();
            nn::socket::FreeAddrInfo(addrInfo);
            Network::GetInstance().m_CommandQueue.Push(std::shared_ptr<Task>(new PingTask(*this)));
        }
        return;
    }

    nn::socket::FreeAddrInfo(addrInfo);

    // 受信
    std::vector<uint8_t> receiveBuffer;
    IcmpResult icmpResult;
    int receiveResult = Receive(&receiveBuffer, 5000);
    if (receiveResult < 0)
    {
        if (*m_Socket == -3)
        {
            PrintStatistics();
            m_Callback();
        }
        else
        {
            PrintSocketError();
            Network::GetInstance().m_CommandQueue.Push(std::shared_ptr<Task>(new PingTask(*this)));
        }
        return;
    }
    else if (!receiveResult)
    {
        // タイムアウト
        icmpResult.type = IcmpResultType_Timeout;
    }
    else
    {
        auto pingReceiveTick = nn::os::GetSystemTick();

        auto& icmpHeader = *reinterpret_cast<posix::icmphdr*>(receiveBuffer.data() + sizeof(posix::iphdr));
        switch (icmpHeader.type)
        {
            // エコー応答
        case ICMP_ECHOREPLY:
            {
                auto& sendIcmpHeader = *reinterpret_cast<posix::icmphdr*>(m_SendBuffer.data());
                if (icmpHeader.un.echo.id == sendIcmpHeader.un.echo.id && icmpHeader.un.echo.sequence == sendIcmpHeader.un.echo.sequence)
                {
                    auto& ipHeader = *reinterpret_cast<posix::iphdr*>(receiveBuffer.data());

                    icmpResult.type = IcmpResultType_EchoReply;
                    icmpResult.elapsedTime = (pingReceiveTick - pingSendTick).ToTimeSpan();
                    icmpResult.responceFrom.S_addr = ipHeader.saddr;
                    icmpResult.ttl = ipHeader.ttl;
                    icmpResult.sequence = icmpHeader.un.echo.sequence;
                    icmpResult.dataLength = ipHeader.tot_len - (sizeof(posix::iphdr) + sizeof(posix::icmphdr));
                }
            }
            break;

            // エラー系応答
        case ICMP_DEST_UNREACH:
        case ICMP_TIME_EXCEEDED:
        case ICMP_PARAMETERPROB:
        case ICMP_SOURCE_QUENCH:
        case ICMP_REDIRECT:
            if (!std::memcmp(reinterpret_cast<posix::icmphdr*>(receiveBuffer.data() + sizeof(posix::iphdr)), m_SendBuffer.data(), sizeof(posix::icmphdr)))
            {
                auto& ipHeader = *reinterpret_cast<posix::iphdr*>(receiveBuffer.data());

                icmpResult.type = static_cast<IcmpResultType>(icmpHeader.type);
                icmpResult.elapsedTime = (pingReceiveTick - pingSendTick).ToTimeSpan();
                icmpResult.responceFrom.S_addr = ipHeader.saddr;
                icmpResult.sequence = icmpHeader.un.echo.sequence;
            }
            break;

        default:
            break;
        }
    }
    PrintResponce(icmpResult);

    // 次の ping 送受信コマンドをキューに追加
    if (++m_CurrentCount < m_TotalCount)
    {
        Network::GetInstance().m_CommandQueue.Push(std::shared_ptr<Task>(new PingTask(*this)));
    }
    // 指定回数終了
    else
    {
        PrintStatistics();
        m_Callback();

        nn::socket::Close(*m_Socket);
    }
} // NOLINT(impl/function_size)


// Returns a controller that can cancel this ping session.
// The controller observes the socket slot shared by this task and its queued copies.
Network::PingController Network::PingTask::GetContoller()
{
    return PingController{m_Socket};
}


// icmp echo 受信
int Network::PingTask::Receive(std::vector<uint8_t>* pOutBuffer, int timeout)
{
    auto beginTick = nn::os::GetSystemTick();

    int sockOptValue;
    nn::socket::PollFd pollFd = {};

    // IP ヘッダを受信
    sockOptValue = sizeof(posix::iphdr);
    if (nn::socket::SetSockOpt(*m_Socket, nn::socket::Level::Sol_Socket, nn::socket::Option::So_SndLoWat, &sockOptValue, sizeof(sockOptValue)) < 0)
    {
        return -1;
    }

    pollFd.fd = *m_Socket;
    pollFd.events = nn::socket::PollEvent::PollIn;

    posix::iphdr ipHeader;
    if (nn::socket::Poll(&pollFd, 1, timeout) <= 0)
    {
        return 0;
    }

    if (nn::socket::Recv(*m_Socket, &ipHeader, sizeof(ipHeader), nn::socket::MsgFlag::Msg_Peek) < 0)
    {
        return -1;
    }

    // IP パケット全体を受信
    pOutBuffer->resize(ipHeader.tot_len);

    sockOptValue = ipHeader.tot_len;
    if (nn::socket::SetSockOpt(*m_Socket, nn::socket::Level::Sol_Socket, nn::socket::Option::So_SndLoWat, &sockOptValue, sizeof(sockOptValue)) < 0)
    {
        return -1;
    }

    timeout -= (nn::os::GetSystemTick() - beginTick).ToTimeSpan().GetMilliSeconds();
    if (nn::socket::Poll(&pollFd, 1, timeout < 0 ? 0 : timeout) <= 0)
    {
        return 0;
    }

    if (nn::socket::Recv(*m_Socket, pOutBuffer->data(), ipHeader.tot_len, nn::socket::MsgFlag::Msg_None) < 0)
    {
        return -1;
    }

    return 1;
}


// Logs the outcome of one Echo Request and counts successful Echo Replies in m_SuccessCount.
// Numeric fields are cast to int to match the "%d" conversions. In particular,
// elapsedTime.GetMilliSeconds() yields a 64-bit count in the NN SDK, so passing
// it directly to "%d" is undefined behavior.
void Network::PingTask::PrintResponce(IcmpResult& result)
{
    // A timeout has no sender, so there is no source address to format.
    if (result.type == IcmpResultType_Timeout)
    {
        NN_LOG("Timeout\n");
        return;
    }

    auto srcIp = nn::socket::InetNtoa(result.responceFrom);
    const int sequence = static_cast<int>(result.sequence);

    switch (result.type)
    {
        // Echo Reply
    case IcmpResultType_EchoReply:
        NN_LOG("response from %s: icmp_seq=%d ttl=%d %d bytes %d ms\n",
                srcIp,
                sequence,
                static_cast<int>(result.ttl),
                static_cast<int>(result.dataLength),
                static_cast<int>(result.elapsedTime.GetMilliSeconds()));
        ++m_SuccessCount;
        break;

        // Destination unreachable
    case IcmpResultType_DestinationUnreachable:
        NN_LOG("response from %s: icmp_seq=%d Destination Unreachable\n", srcIp, sequence);
        break;

        // Time exceeded
    case IcmpResultType_TimeExceeded:
        NN_LOG("response from %s: icmp_seq=%d Time Exceeded\n", srcIp, sequence);
        break;

        // Parameter problem
    case IcmpResultType_ParameterProblem:
        NN_LOG("response from %s: icmp_seq=%d Parameter Problem\n", srcIp, sequence);
        break;

        // Source quench
    case IcmpResultType_SourceQuench:
        NN_LOG("response from %s: icmp_seq=%d Source Quench\n", srcIp, sequence);
        break;

        // Redirect (route change)
    case IcmpResultType_Redirect:
        NN_LOG("response from %s: icmp_seq=%d Redirect (change route)\n", srcIp, sequence);
        break;

    default:
        break;
    }
}


// Logs the session summary: requests sent (m_CurrentCount), replies received
// (m_SuccessCount) and the loss percentage.
void Network::PingTask::PrintStatistics()
{
    // Cast explicitly so the "%d" conversions get ints whatever the counters' declared types are.
    const int sent = static_cast<int>(m_CurrentCount);
    const int received = static_cast<int>(m_SuccessCount);
    // Avoid 0/0 when the session ends before any request was sent (e.g. an early cancel).
    const float lossPercent = (sent > 0) ? static_cast<float>((sent - received) * 100) / sent : 0.0f;

    NN_LOG("--- ping %s(%s) statistics ---\n", m_DestHost.c_str(), nn::socket::InetNtoa(m_DestIp));
    NN_LOG("send %d packets, recieve %d packets, loss %.2f%%\n\n",
            sent,
            received,
            lossPercent);
}


// Logs the most recent socket error code reported by nn::socket::GetLastError().
void Network::PingTask::PrintSocketError()
{
    NN_LOG("errno: %d\n", static_cast<int>(nn::socket::GetLastError()));
}

}
