﻿/*--------------------------------------------------------------------------------*
  Copyright (C)Nintendo All rights reserved.

  These coded instructions, statements, and computer programs contain proprietary
  information of Nintendo and/or its licensed developers and are protected by
  national and international copyright laws. They may not be disclosed to third
  parties or copied or duplicated in any form, in whole or in part, without the
  prior written consent of Nintendo.

  The content herein is highly confidential and should be handled accordingly.
 *--------------------------------------------------------------------------------*/

#include <cstdlib>
#include <atomic>

#include <nn/os.h>
#include <nn/nn_Assert.h>
#include <nn/nn_SdkAssert.h>
#include <nn/nn_Common.h>
#include <nn/result/result_HandlingUtility.h>
#include <nn/socket.h>
#include <nn/util/util_StringUtil.h>
#include <nn/util/util_ScopeExit.h>

#include <nn/netdiag/netdiag_Result.h>
#include <nn/netdiag/netdiag_ResultPrivate.h>
#include <nn/netdiag/netdiag_NatApi.h>
#include <nn/netdiag/netdiag_NatTypes.h>
#include <nn/netdiag/netdiag_NatTypesPrivate.h>
#include <nn/netdiag/detail/netdiag_Log.h>

namespace nn { namespace netdiag { namespace detail {

namespace {

    // ソケットデスクリプタ
    int g_SocketDescriptor = -1;

    // NAT 判定サーバアドレス
    nn::socket::InAddr g_NncsServer[ 2 ];

    // 中断指示
    std::atomic<bool> g_IsInterrupted;

    // ソースポート
    const uint16_t StartEphemeralPort = 49152; // エフェメラルポートとして使う最小ポート番号
    const uint16_t EndEphemeralPort = 65535;   // エフェメラルポートとして使う最大ポート番号
    const int32_t MaxTryBindCount = 10000; // bind しなおす最大回数

    // 内部で使用する、パケット送信用のパラメータ
    struct SendPacketSet
    {
        TestId     testId;          // テスト識別子
        NncsServer server;          // 接続サーバ
        uint16_t   requestPort;     // リクエスト送信先ポート
        uint16_t   responsePort;    // レスポンス受信元ポート
        int32_t    requestInterval; // リクエスト送信間隔(ミリ秒) : 0 なら間隔をあけない
        int        sendTimes;       // 送信回数
    };
    SendPacketSet SendPacketSetForDummy       = { TestId_Dummy,       NncsServer_1, 33334,     0, 0, 5 };
    SendPacketSet SendPacketSetForSimplified1 = { TestId_Simplified1, NncsServer_1, 10025, 10025, 0, 5 };
    SendPacketSet SendPacketSetForSimplified2 = { TestId_Simplified2, NncsServer_1, 10025, 10225, 0, 5 };
    SendPacketSet SendPacketSetForSimplified3 = { TestId_Simplified3, NncsServer_2, 10025, 10025, 0, 5 };

}

//----------------------------------------------------------------
// UDP ソケットをクローズ
void CloseSocket() NN_NOEXCEPT
{
    if ( g_SocketDescriptor < 0 )
    {
        return;
    }

    if ( nn::socket::Close( g_SocketDescriptor ) < 0 )
    {
        NN_DETAIL_NETDIAG_INFO("[netdiag] CloseSocket failed. (err=%d)", nn::socket::GetLastError());
        // socket::Close の失敗については、不正な socketdescriptor か
        // システムシャットダウンぐらいしかないので特にハンドリングしていない。
    }

    g_SocketDescriptor = -1;
}

//----------------------------------------------------------------
// UDP ソケットを作成
nn::Result CreateSocket( ResponseStat* pResponseStat, uint16_t usedPort ) NN_NOEXCEPT
{
    NN_RESULT_THROW_UNLESS(g_SocketDescriptor < 0, ResultAlreadySocketOpened());
    g_SocketDescriptor = nn::socket::Socket(nn::socket::Family::Af_Inet, nn::socket::Type::Sock_Dgram, nn::socket::Protocol::IpProto_Udp);
    if (g_SocketDescriptor < 0)
    {
        NN_DETAIL_NETDIAG_INFO("[netdiag] CreateSocket failed. (err=%d)", nn::socket::GetLastError());
        NN_RESULT_THROW(ResultFailedOpenSocket());
    }
    bool success = false;
    NN_UTIL_SCOPE_EXIT
    {
        if (!success)
        {
            CloseSocket();
        }
    };

    // 使用ポートを得る
    // ランダムですが、それほど厳密である必要はないのでチック値で算出しています。
    // 極端なことを言えば、毎回同じ値であっても問題はありません。
    uint16_t usingEphemeralPort = (nn::os::GetSystemTick().GetInt64Value() % ( EndEphemeralPort - StartEphemeralPort + 1 )) + StartEphemeralPort;

    int tryCount = 0;
    while (NN_STATIC_CONDITION(true))
    {
        // StartEphemeralPort ～ EndEphemeralPort で rotation
        if( usingEphemeralPort >= EndEphemeralPort )
        {
            usingEphemeralPort = StartEphemeralPort;
        }
        else
        {
            ++ usingEphemeralPort;
        }

        // ポート番号が使用済ポートと重なっていたら拒否
        if ( usingEphemeralPort == usedPort )
        {
            continue;
        }

        // 指定ポートへバインド
        nn::socket::SockAddrIn sin;
        sin.sin_family = nn::socket::Family::Af_Inet;
        sin.sin_port = nn::socket::InetHtons( usingEphemeralPort );
        sin.sin_addr.S_addr = nn::socket::InAddr_Any;

        NN_STATIC_ASSERT(sizeof(nn::socket::SockAddr) == sizeof(sin));
        int resultBind = nn::socket::Bind(g_SocketDescriptor, reinterpret_cast<nn::socket::SockAddr*>(&sin), sizeof(sin));
        if (resultBind == 0)
        {
            pResponseStat->sourcePort = usingEphemeralPort;
            success = true;

            NN_DETAIL_NETDIAG_INFO("[netdiag] CreateSocket successfully. (desc=%d) port:%u\n", g_SocketDescriptor, usingEphemeralPort);
            NN_RESULT_SUCCESS;
        }
        else
        {
            // ポートが使用中ならポート番号を増やして再トライ
            nn::socket::Errno lastErrNo = nn::socket::GetLastError();
            if (lastErrNo == nn::socket::Errno::EAddrInUse)
            {
                NN_DETAIL_NETDIAG_INFO("[netdiag] Bind socket failed because EADDRINUSE. (port:%u, tryCount:%d)\n", usingEphemeralPort, tryCount);
                // 一定試行していたらエラーとする
                NN_RESULT_THROW_UNLESS(++tryCount < MaxTryBindCount, ResultLackOfEphemeralPort());
                continue;
            }

            NN_DETAIL_NETDIAG_INFO( "[netdiag] Bind socket failed. (port:%u, resultBind:%d, err:%d, tryCount:%d)\n",
                usingEphemeralPort, resultBind, lastErrNo, tryCount);
            NN_RESULT_THROW(ResultFailedBindSocket());
        }
    }
    NN_ABORT("[netdiag] Unreachable");
}
//----------------------------------------------------------------
// サーバ情報を取得
nn::Result GetServerInfo( nn::socket::InAddr* pAddr, const char* serverName ) NN_NOEXCEPT
{
    // ヒント
    nn::socket::AddrInfo hints;
    memset( &hints, 0, sizeof(hints) );
    hints.ai_family = nn::socket::Family::Af_Inet;
    hints.ai_socktype = nn::socket::Type::Sock_Dgram;

    // サーバ情報取得
    nn::socket::AddrInfo* pResult = NULL;
    nn::socket::AiErrno resultGai = nn::socket::GetAddrInfo( serverName, NULL, &hints, &pResult );
    if ( resultGai != nn::socket::AiErrno::EAi_Success )
    {
        NN_DETAIL_NETDIAG_INFO( "[netdiag] Cannot get %s info. (%s (%d), err=%d)\n", serverName, nn::socket::GAIStrError(resultGai), resultGai, nn::socket::GetLastError() );
        return nn::netdiag::ResultFailedGetAddrInfo();
    }
    else if ( ! pResult )
    {
        NN_DETAIL_NETDIAG_INFO( "[netdiag] Cannot get %s info. (null result)\n", serverName );
        return nn::netdiag::ResultFailedGetAddrInfoData();
    }

    nn::socket::SockAddrIn* pSin = reinterpret_cast<nn::socket::SockAddrIn*>(pResult->ai_addr);
    *pAddr = pSin->sin_addr;

    if ( pResult )
    {
        nn::socket::FreeAddrInfo( pResult );
    }

    NN_RESULT_SUCCESS;
}

//----------------------------------------------------------------
// サーバアドレスを取得
nn::Result GetServerAddress() NN_NOEXCEPT
{
    // サーバ 1
    NN_RESULT_DO( GetServerInfo( &g_NncsServer[ NncsServer_1 ], ServerNameNncs1 ) );
    NN_DETAIL_NETDIAG_INFO( "[netdiag] Server nncs1: %s\n", nn::socket::InetNtoa(g_NncsServer[ NncsServer_1 ]) );

    // サーバ 2
    NN_RESULT_DO( GetServerInfo( &g_NncsServer[ NncsServer_2 ], ServerNameNncs2 ) );
    NN_DETAIL_NETDIAG_INFO( "[netdiag] Server nncs2: %s\n", nn::socket::InetNtoa(g_NncsServer[ NncsServer_2 ]) );

    NN_RESULT_SUCCESS;
}

//----------------------------------------------------------------
// パケット送信
nn::Result SendPacket( SendPacketSet& packetSet ) NN_NOEXCEPT
{
    // 送信パケット
    RequestPacket request;
    memset( &request, 0, sizeof(request) );
    request.testId = nn::socket::InetHtonl( packetSet.testId );

    // 接続先
    nn::socket::SockAddrIn sin = {0};
    memset( &sin, 0, sizeof(sin) );
    sin.sin_addr = g_NncsServer[ packetSet.server ];
    sin.sin_port = nn::socket::InetHtons( packetSet.requestPort );
    sin.sin_family = nn::socket::Family::Af_Inet;

    for( int i=0; i<packetSet.sendTimes; i++ )
    {
        NN_STATIC_ASSERT(sizeof(nn::socket::SockAddr) == sizeof(sin));
        ssize_t sent = nn::socket::SendTo( g_SocketDescriptor, &request, sizeof(request), nn::socket::MsgFlag::Msg_None, reinterpret_cast<nn::socket::SockAddr*>(&sin), sizeof(sin) );
        if ( sent < 0 )
        {
            NN_DETAIL_NETDIAG_INFO( "[netdiag] SendTo failed. [testId=%d] (err=%d)\n", packetSet.testId, nn::socket::GetLastError() );
            return nn::netdiag::ResultFailedSendRequest();
        }
        NN_DETAIL_NETDIAG_INFO( "[netdiag] SendTo %d byte. [testId=%d]\n", sent, packetSet.testId );

        // 送信間隔
        if( packetSet.requestInterval )
        {
            nn::os::SleepThread( nn::TimeSpan::FromMilliSeconds( packetSet.requestInterval ) );
        }
    }
    NN_RESULT_SUCCESS;
}

// Dummy の送信
nn::Result SendDummy() NN_NOEXCEPT
{
    return SendPacket( SendPacketSetForDummy );
}
// Simplified1 の送信
nn::Result SendSimplified1() NN_NOEXCEPT
{
    return SendPacket( SendPacketSetForSimplified1 );
}
// Simplified2 の送信
nn::Result SendSimplified2() NN_NOEXCEPT
{
    return SendPacket( SendPacketSetForSimplified2 );
}
// Simplified3 の送信
nn::Result SendSimplified3() NN_NOEXCEPT
{
    return SendPacket( SendPacketSetForSimplified3 );
}

//----------------------------------------------------------------
// レスポンスの統計情報をクリア
void ClearResponseStat( ResponseStat* pResponseStat ) NN_NOEXCEPT
{
    memset( pResponseStat, 0, sizeof(ResponseStat) );
}

// レスポンスの統計情報を更新
ResponseResult UpdateResponseStat( ResponseStat* pResponseStat, ResponsePacket& response, nn::socket::SockAddrIn& saddr ) NN_NOEXCEPT
{
    char addrString[INET_ADDRSTRLEN];
    uint32_t previousReceivedFlag = pResponseStat->receivedFlag;

    // アドレスとポート
    nn::util::Strlcpy( addrString, nn::socket::InetNtoa(saddr.sin_addr), sizeof(addrString) );
    uint16_t responsePort = nn::socket::InetNtohs(saddr.sin_port);
    NN_DETAIL_NETDIAG_INFO("[netdiag] Response from %s:%d [testId=%d]\n", addrString, responsePort, nn::socket::InetNtohl(response.testId) );
    NN_UNUSED(responsePort);

    // テスト識別子
    uint16_t perceivedPort = static_cast<uint16_t>( nn::socket::InetNtohl(response.PerceivedPort) );
    switch( nn::socket::InetNtohl( response.testId ) )
    {
        case TestId_Simplified1:
            if( perceivedPort > 0 )
            {
                pResponseStat->receivedFlag |= ResponseReceived_Simplified1;
                pResponseStat->port[SimplifiedIndex_1] = perceivedPort;
            }
            break;
        case TestId_Simplified2:
            if( perceivedPort > 0 )
            {
                pResponseStat->receivedFlag |= ResponseReceived_Simplified2;
                pResponseStat->port[SimplifiedIndex_2] = perceivedPort;
            }
            break;
        case TestId_Simplified3:
            if( perceivedPort > 0 )
            {
                pResponseStat->receivedFlag |= ResponseReceived_Simplified3;
                pResponseStat->port[SimplifiedIndex_3] = perceivedPort;
            }
            break;
        default:
            return ResponseResult_Unrelated; // 関係ないパケットだった
            break;
    }

    // すべて受信したか
    if ( pResponseStat->receivedFlag == ResponseReceived_AllSimplified )
    {
        return ResponseResult_Finished;
    }

    // この受信で 1と3 が揃ったか
    if (previousReceivedFlag != ResponseReceived_MappingSimplified && pResponseStat->receivedFlag == ResponseReceived_MappingSimplified )
    {
        return ResponseResult_ReceivedMapping;
    }

    return ResponseResult_OnTheWay;
}

//----------------------------------------------------------------
// レスポンス受信
nn::Result ReceiveResponse( ResponseStat* pResponseStat ) NN_NOEXCEPT
{
    const int64_t receiveTimeoutWhole     = 2500; // 全体のタイムアウト
    const int64_t receiveTimeoutFiltering = 1000; // あと Simplified2(フィルタリング判定) だけになったときのタイムアウト

    // スタート時点のチック
    nn::os::Tick startTick = nn::os::GetSystemTick();
    int64_t receiveTimeout = receiveTimeoutWhole;

    while( NN_STATIC_CONDITION(true) )
    {
        nn::socket::SockAddrIn saddr;
        nn::socket::SockLenT saddrlen = sizeof(saddr);
        ResponsePacket response;
        ssize_t recvSize = nn::socket::RecvFrom(g_SocketDescriptor,
                                                &response, sizeof(response),
                                                nn::socket::MsgFlag::Msg_DontWait,
                                                reinterpret_cast<nn::socket::SockAddr*>(&saddr), &saddrlen );
        if ( recvSize < 0 )
        {
            nn::socket::Errno lastErrno = nn::socket::GetLastError();
            if ( lastErrno != nn::socket::Errno::EAgain )
            {
                NN_DETAIL_NETDIAG_INFO("[netdiag] RecvFrom failed. (errno=%d)\n", lastErrno );
                return nn::netdiag::ResultFailedReceiveResponse();
            }
        }
        else if ( recvSize == sizeof(ResponsePacket) )
        {
            // レスポンス解析結果更新
            ResponseResult result = UpdateResponseStat( pResponseStat, response, saddr );
            if ( result == ResponseResult_Finished ) // 全完了
            {
                break;
            }
            else if ( result == ResponseResult_ReceivedMapping ) // マッピング関連(1と3) が揃った
            {
                // タイムアウトまでが receiveTimeoutFiltering 以上なら早める
                if ( (nn::os::GetSystemTick() - startTick).ToTimeSpan().GetMilliSeconds() > receiveTimeoutFiltering )
                {
                    startTick = nn::os::GetSystemTick();
                    receiveTimeout = receiveTimeoutFiltering;
                }
            }
        }

        // タイムアウト、または中断指示が出ているなら抜ける
        if ( (nn::os::GetSystemTick() - startTick).ToTimeSpan().GetMilliSeconds() > receiveTimeout || g_IsInterrupted )
        {
            break;
        }
        nn::os::SleepThread( nn::TimeSpan::FromMilliSeconds(1) );
    }

    NN_RESULT_SUCCESS;
}

//----------------------------------------------------------------
// 統計結果を元に NATタイプを決定する
// stat に統計結果が格納されている
NatType DetectNatTypeFromStat( ResponseStat stat[2] ) NN_NOEXCEPT
{
    // 結果の解析
    for( int index=0; index<2; index++ )
    {
        ResponseStat* p = &stat[index];
        // S1 または S3 の応答が得られていない場合は失敗
        p->isSuccess = ( p->receivedFlag & ResponseReceived_MappingSimplified) == ResponseReceived_MappingSimplified;

        if ( p->isSuccess )
        {
            // マッピングタイプ: S1 と S3 の port が同じなら EIM, そうでなければ EDM
            p->mappingType = ( p->port[SimplifiedIndex_1] == p->port[SimplifiedIndex_3] )?
                MappingType_Eim : MappingType_Edm;

            // フィルタリングタイプ: S2 の応答が得られなかったら PDF, そうでなければ PIF
            p->filteringType = (( p->receivedFlag &  ResponseReceived_Simplified2 ) == 0 )?
                FilteringType_Pdf: FilteringType_Pif;

            // ポート差分
            p->portDifference = static_cast<uint16_t>( std::abs( p->port[SimplifiedIndex_1] - p->port[SimplifiedIndex_3] ) );
        }
    }

#if 1
    // 解析結果の表示
    for( int index=0; index<2; index++ )
    {
        ResponseStat* p = &stat[index];
        NN_DETAIL_NETDIAG_INFO( "[netdiag] ---- NatTypeDetection :try %d\n", index + 1 );
        NN_DETAIL_NETDIAG_INFO( "[netdiag]  Result: %s\n", p->isSuccess? "SUCCESS": "FAILED" );
        NN_DETAIL_NETDIAG_INFO( "[netdiag]  UsingPort: %d\n", p->sourcePort );
        if ( p->isSuccess )
        {
            NN_DETAIL_NETDIAG_INFO( "[netdiag]  MappingType:    %s\n", p->mappingType == MappingType_Eim? "EIM": "EDM" );
            NN_DETAIL_NETDIAG_INFO( "[netdiag]  FilteringType:  %s\n", p->filteringType == FilteringType_Pdf? "PDF": "PIF" );
            NN_DETAIL_NETDIAG_INFO( "[netdiag]  PortDifference: %d\n", p->portDifference );
        }
    }
#endif

    // 1回目、2回目いずれかが失敗なら 「Z」
    if ( ! stat[0].isSuccess || ! stat[1].isSuccess )
    {
        return NatType_Z;
    }
    // 1回目と2回目のマッピングタイプとフィルタリングタイプが同じ場合
    if ( stat[0].mappingType == stat[1].mappingType && stat[0].filteringType == stat[1].filteringType )
    {
        // マッピングタイプが「EIM」でフィルタリングタイプが「PIF」の場合は、NATタイプ「A」
        if ( stat[0].mappingType == MappingType_Eim && stat[0].filteringType == FilteringType_Pif )
        {
            return NatType_A;
        }

        // マッピングタイプが「EIM」でフィルタリングタイプが「PDF」の場合は、NATタイプ「B」
        if ( stat[0].mappingType == MappingType_Eim && stat[0].filteringType == FilteringType_Pdf )
        {
            return NatType_B;
        }

        // マッピングタイプが「EDM」で1回目と2回目のポート差分が一致している場合は、NATタイプ「C」
        if ( stat[0].mappingType == MappingType_Edm && stat[0].portDifference == stat[1].portDifference )
        {
            return NatType_C;
        }
    }

    // 上記以外はNATタイプ「D」
    return NatType_D;
}

//----------------------------------------------------------------
// NAT タイプ判定を行う
//
nn::Result DetectNatType( NatType* pType ) NN_NOEXCEPT
{
    g_IsInterrupted = false; // 中断フラグ

    // サーバのアドレス情報取得
    NN_RESULT_DO( GetServerAddress() );

    ResponseStat stat[2]; // NATタイプ判定は2回。
    for( int detectCount=0; detectCount<2; detectCount++ )
    {
        NN_DETAIL_NETDIAG_INFO("[netdiag] NAT type detection (TRY: %d)\n", detectCount + 1);

        // 前のポートと重ならないように、2回目は1回目の番号を格納
        uint16_t usedPort = (detectCount == 1)? stat[0].sourcePort: 0;

        int tryCount = 0;
        while( tryCount < RetryTestTime && ! g_IsInterrupted )
        {
            // 今回の統計情報クリア
            ResponseStat* pCurrentStat = &stat[detectCount];
            ClearResponseStat( pCurrentStat );

            // UDP ソケット作成
            NN_RESULT_DO( CreateSocket( pCurrentStat, usedPort ) );
            NN_UTIL_SCOPE_EXIT
            {
                // UDP ソケットクローズ
                CloseSocket();
            };

            // パケット送信
            NN_RESULT_DO( SendDummy() );
            NN_RESULT_DO( SendSimplified1() );
            NN_RESULT_DO( SendSimplified2() );
            NN_RESULT_DO( SendSimplified3() );

            // パケット待ち受け
            NN_RESULT_DO( ReceiveResponse( pCurrentStat ) );

            // マッピング関連(1と3)の結果が揃っていたら抜ける
            if ( (pCurrentStat->receivedFlag & ResponseReceived_MappingSimplified) == ResponseReceived_MappingSimplified )
            {
                break;
            }
            tryCount ++;
        }
    }

    // 中断されているか
    NN_RESULT_THROW_UNLESS( ! g_IsInterrupted, nn::netdiag::ResultInterrupted() );

    // 2回の統計結果を元に判定
    *pType = DetectNatTypeFromStat( stat );

    NN_RESULT_SUCCESS;
}

//----------------------------------------------------------------
// NAT タイプ判定の中断指示を出す
//
void InterruptDetectNatType() NN_NOEXCEPT
{
    g_IsInterrupted = true;
}

}}} // nn::netdiag::detail
