﻿/*--------------------------------------------------------------------------------*
  Copyright (C)Nintendo All rights reserved.

  These coded instructions, statements, and computer programs contain proprietary
  information of Nintendo and/or its licensed developers and are protected by
  national and international copyright laws. They may not be disclosed to third
  parties or copied or duplicated in any form, in whole or in part, without the
  prior written consent of Nintendo.

  The content herein is highly confidential and should be handled accordingly.
 *--------------------------------------------------------------------------------*/

#pragma once

#include "Port/Port.h"

#include <cstdio>
#include <cstring>

namespace nnt { namespace net { namespace p2p {

// ---- Topology -------------------------------------------------------------
static const uint32_t MaxNodes = 32; // Upper bound on participating nodes, this node included.

// ---- Buffers and packet sizes ---------------------------------------------
static const int      SockRecvBufLen = 64 * 1024;        // Internal receive buffer length for the receive socket.
static const uint32_t StackLen       = 16 * 1024;        // Stack size in bytes for the send, receive and print threads.
static const uint32_t MaxPacketLen   = 2 * 1024;         // Largest packet, header included (NetMessage spans exactly this).
static const uint32_t MaxSamples     = 32 * 1024;        // Capacity of each timing-sample array used for statistics.
static const uint32_t IpBufLen       = 16;               // Longest dotted-quad IPv4 text ("255.255.255.255", 15 chars) + NUL.
static const uint32_t PrintBufferLen = 10 * 1024 * 1024; // Size of the static print and CSV text buffers.

// ---- Timing ---------------------------------------------------------------
static const uint32_t FrameTimeoutMS       = 500;  // How long threads wait for each other when syncing on a frame.
static const uint32_t InitThreadTimeoutSec = 4;    // How long to wait for the send/recv threads to finish initializing.
static const uint32_t ThreadExitTimeoutSec = 10;   // How long to wait for a thread to exit after it is told to shut down.
static const float    RecvFrameReserveMs   = 0.1f; // Time reserved for the recv thread to sync with the frame.

// printf-style logging helpers that forward to PacketRateTest's static BufferedLogger
// instances. Text is therefore only buffered; it appears when that logger's Flush() runs.
// P2P_LOG prefixes every message with "[NetTest] "; CSV_LOG passes the format through unchanged.
// `##__VA_ARGS__` (a compiler extension) drops the trailing comma so either macro can be
// called with just a format string.
#define P2P_LOG(format, ...) PacketRateTest::g_Logger.Log("[NetTest] " format, ##__VA_ARGS__)
#define CSV_LOG(format, ...) PacketRateTest::g_CsvLogger.Log(format, ##__VA_ARGS__)

// Parameters for the print thread.
// The object owns isExitedEvent: it is initialized in the constructor and finalized in
// the destructor. Copying is therefore disabled. An implicit copy would duplicate the
// event state, and both destructors would then finalize it.
class PrintParam
{
public:
    PrintParam() NN_NOEXCEPT :
        doShutdown(false)
    {
        NetTest::InitEvent(&isExitedEvent);
    }

    ~PrintParam() NN_NOEXCEPT
    {
        NetTest::FinalizeEvent(&isExitedEvent);
    }

    PrintParam(const PrintParam&) = delete;
    PrintParam& operator=(const PrintParam&) = delete;

    bool doShutdown;              // Starts false; presumably set by the owner to stop the print thread — confirm.
    NetTest::Thread thread;       // Handle of the print thread.
    NetTest::Event isExitedEvent; // Event owned by this object (see class comment).
};

// Accumulates formatted text in caller-supplied storage; nothing is printed until
// Flush() is called explicitly. The class declares no destructor, so it never frees
// the storage it was given.
class BufferedLogger
{
public:
    // pBuffer / bufferLen: backing storage and its size in bytes, fixed for the logger's lifetime.
    BufferedLogger(char* pBuffer, uint32_t bufferLen) NN_NOEXCEPT;

    // Appends printf-style formatted text to the buffer.
    void Log(const char* format, ...) NN_NOEXCEPT;

    // Prints buffered text. The exact role of charCount is defined in the .cpp — TODO confirm.
    void Flush(uint32_t charCount) NN_NOEXCEPT;

private:
    char* const    m_pBuffer;      // Backing storage (pointer fixed at construction).
    const uint32_t m_BufferLen;    // Capacity of m_pBuffer.
    uint32_t       m_PrintPos;     // Presumably the write cursor used by Log() — confirm.
    uint32_t       m_FlushPos;     // Presumably how far Flush() has printed — confirm.
    bool           m_IsOutOfSpace; // Presumably latched once the buffer fills — confirm.
};

// Endpoint of one node: IPv4 address as text plus a port (Params keeps one per node).
struct PeerAddr
{
    char     pIp[IpBufLen]; // NUL-terminated dotted-quad text, at most "255.255.255.255".
    uint16_t port;
};

// Classification of a receive attempt, as returned by PacketRateTest::CheckRecvError().
// Values are numbered consecutively from zero.
enum RecvError
{
    RecvError_None = 0,
    RecvError_Ready,
    RecvError_NotReady,
    RecvError_Error
};

// Holds every setting parsed from the command line.
class Params
{
public:
    Params() NN_NOEXCEPT;

    // Fills the fields below from argV; the bool presumably reports parse success — confirm.
    bool ParseCommandLine(int argC, const char* const* argV) NN_NOEXCEPT;

    // --- Topology ---
    PeerAddr pAddrs[MaxNodes];  // Presumably one address per node — confirm.
    uint8_t  nodeCount;
    uint8_t  nodeIndex;         // Presumably this process's own index into pAddrs — confirm.

    // --- Traffic shape ---
    uint32_t packetSize;
    uint32_t packetsPerFrame;
    float    frameIntervalMs;

    // --- Timing and reporting cadence ---
    uint32_t reportIntervalMs;
    uint32_t testDurationSec;
    uint32_t rttFrameInterval;  // Presumably frames between RTT probes — confirm.

    // --- Switches ---
    bool reportCallTime;
    bool doVerifyData;
    bool doTermOnBadData;
    bool reportNodeStats;
    bool reportCsv;
    bool noReport;
};

// Process-wide send/receive call-time and frame statistics.
// Clear() resets every counter and min/max tracker except rttFrameCounter. It does not
// zero the sample arrays; presumably only the first sendCount/recvCount entries are
// read — confirm in the .cpp.
class GlobalStats
{
public:

    // Starts in exactly the state Clear() produces. Clear() does not touch
    // rttFrameCounter, so it is zeroed here explicitly; otherwise any non-static
    // instance would read an indeterminate value.
    GlobalStats() NN_NOEXCEPT :
        rttFrameCounter(0)
    {
        Clear();
    }

    // Resets the frame count, call counters, duration sums and min/max trackers.
    void Clear() NN_NOEXCEPT
    {
        frameCount = 0;

        // Min trackers start at the largest uint32_t so any recorded value is smaller.
        minSendMicroSec = 0xFFFFFFFF;
        maxSendMicroSec = 0;
        sendCount       = 0;
        sendDurationSum = 0;

        minRecvMicroSec = 0xFFFFFFFF;
        maxRecvMicroSec = 0;
        recvCount       = 0;
        recvDurationSum = 0;

        frameDurationSumMs = 0.0f;
    }

    uint32_t frameCount;    // Number of frames that passed within stats duration.

    uint32_t minSendMicroSec;
    uint32_t maxSendMicroSec;
    uint32_t sendCount;
    uint32_t sendDurationSum;

    uint32_t minRecvMicroSec;
    uint32_t maxRecvMicroSec;
    uint32_t recvCount;
    uint32_t recvDurationSum;

    float frameDurationSumMs;

    uint32_t pSendSamples[MaxSamples];
    uint32_t pRecvSamples[MaxSamples];

    uint32_t rttFrameCounter; // Number of frames since we last sent an RTT packet.
};

// Per-peer packet counters and round-trip-time bookkeeping.
class PeerStats
{
public:

    // Zeroes every member, including frameSent and frameRecvd. Those two were
    // previously omitted and left indeterminate. Initializers are listed in
    // declaration order, which is the order they actually run in.
    PeerStats() NN_NOEXCEPT :
        frameSent(0),
        frameRecvd(0),
        packetsSent(0),
        packetsRecved(0),
        packetsOutOfRange(0),
        packetsDropped(0),
        packetsBadData(0),
        rttStart(0),
        totalRttTicks(0),
        rttSeqNum(0),
        ackRecvCount(0)
    {}

    // Both are defined out of line.
    void Print(uint32_t statsDurationMS, uint32_t frameCount) NN_NOEXCEPT;
    void Clear() NN_NOEXCEPT;

    uint32_t frameSent;
    uint32_t frameRecvd;
    uint32_t packetsSent;
    uint32_t packetsRecved;
    int32_t  packetsOutOfRange;
    int32_t  packetsDropped;
    uint32_t packetsBadData;
    NetTest::Tick rttStart;
    NetTest::Tick totalRttTicks;
    uint32_t rttSeqNum;
    uint32_t ackRecvCount;
};

// Bookkeeping for one node of the test: identity, sequence-number tracking,
// socket address and that node's packet statistics.
struct NodeInfo
{
    uint32_t               nodeIndex;          // Presumably this node's index into Params::pAddrs — confirm.
    uint32_t               baseSequenceNumber;
    uint32_t               highSeqNum;         // Presumably the highest sequence number seen — confirm.
    uint32_t               framePacketCount;
    nn::socket::SockAddrIn addr;               // Socket address of the node.
    PeerStats              packetStats;
};

#pragma pack(push, 1)

// Flags carried in NetHeader::flags; the explicit uint8_t base keeps it one byte wide.
enum class NetHeaderFlags : uint8_t
{
    None = 0,
    Rtt  = 1,
    Ack  = 2
};

// Header at the front of every packet. The surrounding pack(1) removes padding so the
// struct can be transmitted as raw bytes. checksum and sequenceNum are read back through
// NetTest::Ntohs/Ntohl in the NetMessage accessors, i.e. they are kept in network byte order.
struct NetHeader
{
    uint8_t        headerLen;
    uint8_t        nodeIndex;
    uint16_t       checksum;
    uint32_t       sequenceNum;
    NetHeaderFlags flags;
};

// One packet: the header followed by a payload sized so that the whole message
// occupies exactly MaxPacketLen bytes.
class NetMessage
{
private:
    NetHeader m_Header;

    uint16_t CalcCheckSum(void* pData, uint32_t dataLen) NN_NOEXCEPT;

public:

    bool VerifyCheckSum(uint32_t packetLen) NN_NOEXCEPT;
    bool ValidateHeader(uint32_t packetLen, const nn::socket::SockAddrIn& remoteAddr) NN_NOEXCEPT;
    void SetHeader(uint8_t nodeIndex, uint32_t sequenceNum, NetHeaderFlags flags, uint32_t packetSize) NN_NOEXCEPT;

    // Header accessors; multi-byte fields are converted to host byte order.
    uint8_t  GetHeaderLen()   const NN_NOEXCEPT { return m_Header.headerLen; }
    uint8_t  GetNodeIndex()   const NN_NOEXCEPT { return m_Header.nodeIndex; }
    uint16_t GetChecksum()    const NN_NOEXCEPT { return NetTest::Ntohs(m_Header.checksum); }
    uint32_t GetSequenceNum() const NN_NOEXCEPT { return NetTest::Ntohl(m_Header.sequenceNum); }
    NetHeaderFlags GetFlags() const NN_NOEXCEPT { return m_Header.flags; }

    uint8_t payload[MaxPacketLen - sizeof(NetHeader)];
};

// The packet format depends on this exact layout. Fail the build if a field change adds
// bytes or padding, instead of silently changing what goes on the wire.
static_assert(sizeof(NetHeaderFlags) == 1, "NetHeaderFlags must be one byte");
static_assert(sizeof(NetHeader) == 1 + 1 + 2 + 4 + 1, "NetHeader must be padding-free; it is part of the packet format");
static_assert(sizeof(NetMessage) == MaxPacketLen, "NetMessage must span exactly MaxPacketLen bytes");
#pragma pack(pop)

// Error state reported by the send/recv threads (stored in ThreadData::error).
// Values are numbered consecutively from zero.
enum ThreadError
{
    ThreadError_None = 0,
    ThreadError_Connection,
    ThreadError_Send,
    ThreadError_Recv,
    ThreadError_Unknown
};

// Parameter for both the send and recv threads.
// Owns three events (initialized in the constructor) and, once isThreadInit is set,
// a thread. Copying is disabled so those resources are never released twice.
class ThreadData
{
public:

    ThreadData() NN_NOEXCEPT : doShutdown(false), isThreadInit(false), error(ThreadError_None)
    {
        NetTest::InitEvent(&isReady);
        NetTest::InitEvent(&doFrame);
        NetTest::InitEvent(&isExited);
    }

    // Tears down in reverse order of acquisition. The thread is destroyed before the
    // events are finalized, because this object is the thread's parameter and the
    // thread may still reference these events.
    ~ThreadData() NN_NOEXCEPT
    {
        if( isThreadInit )
        {
            NetTest::DestroyThread(&hThread);
            isThreadInit = false;
        }

        NetTest::FinalizeEvent(&isExited);
        NetTest::FinalizeEvent(&doFrame);
        NetTest::FinalizeEvent(&isReady);
    }

    ThreadData(const ThreadData&) = delete;
    ThreadData& operator=(const ThreadData&) = delete;

    NetTest::Event isReady;
    NetTest::Event doFrame;
    NetTest::Event isExited;
    NetTest::Thread hThread;

    bool doShutdown;
    bool isThreadInit;   // True once hThread has been created and must be destroyed.
    ThreadError error;
};

// Main test class.
// The only public entry point is the static RunTest(). The constructor is private, so
// instances can only be created by the class itself or its friends.
// Most state lives in static members: parsed params, global stats, the node table,
// log buffers and thread stacks. Per-instance state is the peer count, the
// ThreadData/PrintParam objects for the worker threads and two timestamps.
class PacketRateTest
{
private:
    PacketRateTest() NN_NOEXCEPT
    : m_PeerCount(0) {}

    bool RunTestPrivate(int argC, const char * const * argV) NN_NOEXCEPT;
    bool Init() NN_NOEXCEPT;
    void CleanUp() NN_NOEXCEPT;
    bool ProcessPacket(const NetMessage& packet, int recvSocket, const NetTest::Tick& recvdTick) NN_NOEXCEPT;
    int InitRecvSocket() NN_NOEXCEPT;
    RecvError CheckRecvError(int recvSocket, ssize_t recvLen, NetTest::Tick startFrame) NN_NOEXCEPT;

    // Presumably resolves a peer handle through g_pPeerIndeces into g_pNodes — confirm.
    NodeInfo& GetPeerByHandle(uint32_t peerHandle) NN_NOEXCEPT;
    void PrintStats(uint32_t statsDurationMs, uint32_t frameDurationMs) NN_NOEXCEPT;
    uint32_t GetRandom() NN_NOEXCEPT;

    // ******************** Friends ********************
    // The thread entry points and these helper classes need access to private state.
    friend NETTEST_DECLARE_THREAD_FUNC(RecvThread, pArg);
    friend NETTEST_DECLARE_THREAD_FUNC(SendThread, pArg);
    friend class NetMessage;
    friend class CallTimeParser;

    // Two implementations of each socket call. The Timed* variants presumably also record
    // call durations (cf. Params::reportCallTime) — confirm. Callers go through
    // g_pRecvFromFn / g_pSendToFn, which point at the selected variant.
    static ssize_t TimedRecvFrom    (int socket, NetMessage& message, nn::socket::SockAddrIn* pAddr, nn::socket::SockLenT* pAddrLen, NetTest::Tick& recvdTick) NN_NOEXCEPT;
    static ssize_t NormalRecvFrom   (int socket, NetMessage& message, nn::socket::SockAddrIn* pAddr, nn::socket::SockLenT* pAddrLen, NetTest::Tick& recvdTick) NN_NOEXCEPT;
    static ssize_t TimedSendTo      (int socket, NetMessage& message, NodeInfo& peer, uint32_t packetSize) NN_NOEXCEPT;
    static ssize_t NormalSendTo     (int socket, NetMessage& message, NodeInfo& peer, uint32_t packetSize) NN_NOEXCEPT;

    static ssize_t (*g_pRecvFromFn) (int socket, NetMessage& message, nn::socket::SockAddrIn* pAddr, nn::socket::SockLenT* pAddrLen, NetTest::Tick& recvdTick) NN_NOEXCEPT;
    static ssize_t (*g_pSendToFn)   (int socket, NetMessage& message, NodeInfo& peer, uint32_t packetSize) NN_NOEXCEPT;

public:

    static bool RunTest(int argC, const char * const * argV) NN_NOEXCEPT;
    static const Params& GetParams() NN_NOEXCEPT
    { return g_Params; }

    // Targets of the P2P_LOG / CSV_LOG macros.
    static BufferedLogger g_Logger;
    static BufferedLogger g_CsvLogger;

private:

    static bool SendRttToPeers(PacketRateTest* pTest, ThreadData& params, int sendSocket, uint32_t sequenceNumber, uint32_t sentPackets) NN_NOEXCEPT;
    static bool SendToPeers(PacketRateTest* pTest, ThreadData& params, int sendSocket, uint32_t sequenceNumber, uint32_t sentPackets) NN_NOEXCEPT;

    static Params g_Params; // Parsed params
    static GlobalStats g_Stats;

    static NodeInfo g_pNodes[MaxNodes];       // All nodes including self.
    static uint32_t g_pPeerIndeces[MaxNodes]; // Maps 'peer' index to 'node' index.

    // Text storage of PrintBufferLen bytes each for the text and CSV output.
    static char g_pPrintBuffer[PrintBufferLen];
    static char g_pCsvBuffer[PrintBufferLen];

    // Statically allocated thread stacks, 4096-byte aligned via NN_ALIGNAS.
    static NN_ALIGNAS(4096) unsigned char g_pRecvThreadStack[StackLen];
    static NN_ALIGNAS(4096) unsigned char g_pSendThreadStack[StackLen];
    static NN_ALIGNAS(4096) unsigned char g_pPrintThreadStack[StackLen];

    static NetMessage g_RecvNetMessage;

    // Per-instance state. Declaration order fixes construction order; keep it stable.
    uint32_t m_PeerCount;

    ThreadData m_RecvThreadData;
    ThreadData m_SendThreadData;
    PrintParam m_PrintParam;

    NetTest::Tick m_TestStart;
    NetTest::Tick m_StatsStart;
};

// Entry points of the receive and send worker threads (signatures generated by
// NETTEST_DECLARE_THREAD_FUNC); both are friends of PacketRateTest. What pParam points
// to is defined by the implementation — presumably the test object or its ThreadData; confirm.
NETTEST_DECLARE_THREAD_FUNC(RecvThread, pParam);
NETTEST_DECLARE_THREAD_FUNC(SendThread, pParam);

}}} // Namespaces
