﻿/*--------------------------------------------------------------------------------*
  Copyright (C)Nintendo All rights reserved.

  These coded instructions, statements, and computer programs contain proprietary
  information of Nintendo and/or its licensed developers and are protected by
  national and international copyright laws. They may not be disclosed to third
  parties or copied or duplicated in any form, in whole or in part, without the
  prior written consent of Nintendo.

  The content herein is highly confidential and should be handled accordingly.
 *--------------------------------------------------------------------------------*/

#include <nn/nn_Assert.h>
#include <nn/util/util_FormatString.h>
#include <nn/nn_SdkLog.h>
#include <nn/socket.h>

#include <cstring>
#include <algorithm> //std::max

#include "testNet_ApiCommon.h"
#include "Complex/testNet_UnitCommon.h"
#include "Complex/testNet_SelectUnitData.h"
#include "Complex/testNet_SelectUnitNetworkCommon.h"

namespace NATF {
namespace API {

/**
 * @brief Constructs the select-unit networking logic common to every SocketContainerActor role.
 *
 * Takes a reference on the shared unit data and on the validator obtained from it,
 * and initializes the access lock. No reference is taken on pSelfThread here, and
 * the peer thread starts out unset (see SetPeer()).
 *
 * @param pSelectUnitData  shared unit data; asserted to be non-NULL
 * @param role             role this instance acts as
 * @param pSelfThread      thread whose state RunLoop() polls; asserted to be non-NULL
 */
SelectUnitNetworkCommon::SelectUnitNetworkCommon(SelectUnitData* pSelectUnitData, SocketContainerActor role, UnitTestThreadBase* pSelfThread) :
    m_pSelectUnitData(pSelectUnitData),
    m_pValidator(NULL),
    m_Role(role),
    m_pSelfThread(pSelfThread),
    m_pPeerThread(NULL)
{
    //UNIT_TEST_TRACE("");
    NN_ASSERT(m_pSelectUnitData != NULL);
    m_pSelectUnitData->addReference();

    // the validator comes from the unit data and is reference-counted separately
    m_pValidator = m_pSelectUnitData->GetValidator();
    NN_ASSERT(m_pValidator != NULL);
    m_pValidator->addReference();

    NN_ASSERT(m_pSelfThread != NULL);
    // NOTE(review): the 'true' argument presumably requests a recursive mutex -- confirm against nn::os
    nn::os::InitializeMutex(&m_CommonAccessLock, true, 0);
};

/**
 * @brief Replaces the peer thread whose state RunLoop() watches.
 *
 * The reference on the previous peer (if any) is released and a reference is
 * taken on the new one (if non-NULL). Passing the peer that is already set
 * leaves the reference count untouched.
 *
 * @param peer  new peer thread, or NULL to unset the peer
 */
void SelectUnitNetworkCommon::SetPeer(UnitTestThreadBase* peer)
{
    // same peer as before: the reference already held stays as it is
    if (m_pPeerThread == peer)
    {
        return;
    }

    // hand back the reference on the peer being replaced
    if (m_pPeerThread != NULL)
    {
        m_pPeerThread->releaseReference();
    }

    m_pPeerThread = peer;

    if (m_pPeerThread != NULL)
    {
        m_pPeerThread->addReference();
    }
}

/**
 * @brief Finalizes the access lock and releases every reference still held.
 *
 * Releases the validator and unit-data references taken by the constructor, and
 * also the peer-thread reference if one is still outstanding.
 */
SelectUnitNetworkCommon::~SelectUnitNetworkCommon()
{
    //UNIT_TEST_TRACE("");
    nn::os::FinalizeMutex(&m_CommonAccessLock);

    // SetPeer() takes a reference that RunLoop() only gives back at its bail
    // label. If RunLoop() never ran, that reference would leak; after a normal
    // RunLoop() the pointer is already NULL and this is a no-op.
    if (NULL != m_pPeerThread)
    {
        m_pPeerThread->releaseReference();
        m_pPeerThread = NULL;
    }

    // release the references taken in the constructor
    m_pValidator->releaseReference();
    m_pSelectUnitData->releaseReference();
}

/**
 * @brief Main select() loop: waits on the sockets in m_SocketContainerList and dispatches ready events.
 *
 * Requires a peer thread to have been set via SetPeer(); otherwise a validation
 * failure is recorded and the function returns immediately.
 *
 * The loop ends when any of the following happens:
 *  - the validator reports a failure,
 *  - m_SocketContainerList is empty,
 *  - GetMaxFd() returns -1 or all three fd sets end up NULL,
 *  - Select fails with an errno other than EAgain / EWouldBlock / EIntr,
 *  - this thread or the peer thread reaches STATE_FINISHING (checked after Select returns),
 *  - events were left unhandled and the retry counter exceeded 10.
 *
 * On exit, if a peer is set, its reference is released and both m_pPeerThread
 * and m_pSelfThread are set to NULL.
 */
void SelectUnitNetworkCommon::RunLoop()
{
    nn::socket::TimeVal timeValue;
    nn::socket::TimeVal* pTimeValue;
    int rc, maxfd = -1;
    unsigned int originalEventsHandled = 0, eventsHandled = 0;
    nn::socket::FdSet readfds, writefds, exceptionfds;
    nn::socket::FdSet *pReadFds, *pWriteFds, *pExceptionFds;
    IUnitTestThread::State currentState;
    unsigned int currentRetries = 0;

    if (NULL == m_pPeerThread)
    {
        SVALIDATE_FAIL(m_pValidator, true, "%s peer thread not set, bailing.", SocketContainerActorAsString(m_Role));
        goto bail;
    };

    for (;;)
    {
        // re-point at the local sets every pass; GetSocketContainerTypeCountAndSet()
        // sets a pointer to NULL when no container matches
        nn::socket::FdSetZero(&readfds); pReadFds = &readfds;
        nn::socket::FdSetZero(&writefds); pWriteFds = &writefds;
        nn::socket::FdSetZero(&exceptionfds); pExceptionFds = &exceptionfds;
        pTimeValue = &timeValue;

        if (m_pValidator->DidFail() == true)
        {
            UNIT_TEST_TRACE("%s bailed because validator failed", SocketContainerActorAsString(m_Role));
            goto bail;
        }
        else if (m_SocketContainerList.size() == 0)
        {
            UNIT_TEST_TRACE("%s bailed because socketContainerList.size() = %d", SocketContainerActorAsString(m_Role), m_SocketContainerList.size());
            goto bail;
        };

        // the client role only waits for READ containers in the read set; any other
        // role additionally waits on LISTEN and EMBRYONIC containers
        if (SOCKET_CONTAINER_ACTOR_CLIENT == m_Role)
        {
            GetSocketContainerTypeCountAndSet(static_cast<SocketContainerTypeFlags>(SOCKET_CONTAINER_TYPE_READ), pReadFds);
        }
        else
        {
            GetSocketContainerTypeCountAndSet(static_cast<SocketContainerTypeFlags>(SOCKET_CONTAINER_TYPE_LISTEN | SOCKET_CONTAINER_TYPE_READ | SOCKET_CONTAINER_TYPE_EMBRYONIC), pReadFds);
        };

        GetSocketContainerTypeCountAndSet(SOCKET_CONTAINER_TYPE_WRITE, pWriteFds);
        GetSocketContainerTypeCountAndSet(SOCKET_CONTAINER_TYPE_EXCEPTION, pExceptionFds);
        // the timeout is fetched from the shared unit data on every pass
        m_pSelectUnitData->GetSelectTimeout(pTimeValue);

        maxfd = GetMaxFd();
        if (SVALIDATE_FAIL(m_pValidator, (-1 == maxfd), "%s GetMaxFd returned: %d", SocketContainerActorAsString(m_Role), maxfd))
        {
            goto bail;
        }
        // containers exist but none of them landed in any fd set: nothing to wait on
        else if (SVALIDATE_FAIL(m_pValidator,
                                (NULL == pReadFds && NULL == pWriteFds && NULL == pExceptionFds && 0 != m_SocketContainerList.size()),
                                "%s all fd sets null but socketContainerList size: %d, errno: %d", SocketContainerActorAsString(m_Role), m_SocketContainerList.size(), nn::socket::GetLastError()))
        {

            goto bail;
        }
        else  if (-1 == (rc = nn::socket::Select(maxfd + 1, pReadFds, pWriteFds, pExceptionFds, pTimeValue)))
        {
            // EAgain / EWouldBlock / EIntr are retried; every other errno is a validation failure
            nn::socket::Errno errorNumber = nn::socket::GetLastError();
            if (errorNumber == nn::socket::Errno::EAgain || errorNumber == nn::socket::Errno::EWouldBlock || errorNumber == nn::socket::Errno::EIntr)
            {
                UNIT_TEST_TRACE("%s select rv=%d, errno: %d; socketContainerList size: %d, continuing.", SocketContainerActorAsString(m_Role), rc, errorNumber, m_SocketContainerList.size());
                continue;
            }
            else
            {
                SVALIDATE_FAIL(m_pValidator, true, "%s select rv=%d, errno: %d; socketContainerList size: %d, bailing", SocketContainerActorAsString(m_Role), rc, errorNumber, m_SocketContainerList.size());
                goto bail;
            };
        }
        // the thread states are only looked at once Select has returned without error
        else if (IUnitTestThread::STATE_FINISHING == (currentState = m_pSelfThread->GetState()))
        {
            // not an error
            UNIT_TEST_TRACE("%s bailed because own state advanced to FINISHING", SocketContainerActorAsString(m_Role));
            goto bail;
        }
        else if (IUnitTestThread::STATE_FINISHING == (currentState = m_pPeerThread->GetState())) //TODO
        {
            // not an error
            UNIT_TEST_TRACE("%s bailed because peer state advanced to FINISHING", SocketContainerActorAsString(m_Role));
            goto bail;
        }
        else if (rc == 0)
        {
            // Select timed out with nothing ready; just go around again
            // TODO: check total timeout, not currently tracked
            UNIT_TEST_TRACE("%s select t/o rv=%d, errno: %d, socketContainerList size: %d", SocketContainerActorAsString(m_Role), rc, nn::socket::GetLastError(), m_SocketContainerList.size());
            continue;
        }
        else
        {
            // rc is the number of ready descriptors reported by Select; subtract
            // what the handlers report as handled to see what is left over
            originalEventsHandled = eventsHandled = rc;

            // NOTE(review): eventsHandled is unsigned, so this wraps around if
            // DispatchEvents() ever reports more than rc -- confirm that cannot happen
            eventsHandled -= DispatchEvents(pReadFds, pWriteFds, pExceptionFds);

            // NOTE(review): currentRetries is never reset, so the limit of 10 applies
            // to the whole run rather than to consecutive passes -- confirm intent
            if (0 != eventsHandled && ++currentRetries > 10)
            {
                if (SVALIDATE_FAIL(m_pValidator,
                                   (true),
                                   "%s did not handle all events, original: %d, left: %d", SocketContainerActorAsString(m_Role),
                                   originalEventsHandled, eventsHandled))
                {
                    goto bail;
                };
            };
            UNIT_TEST_TRACE("%s %d/%d events left (retries left: %d/10", SocketContainerActorAsString(m_Role), eventsHandled, originalEventsHandled, currentRetries);
        };
    };

bail:
    // give back the reference taken in SetPeer(); both thread pointers are cleared
    // only when a peer was set
    if (NULL != m_pPeerThread)
    {
        m_pPeerThread->releaseReference();
        m_pPeerThread = m_pSelfThread = NULL;
    };
    return;
}; //NOLINT(impl/function_size)

void SelectUnitNetworkCommon::GetSiftedList()
{
    std::list<SocketContainer> socketContainerListCopy;
    m_pSelectUnitData->GetSocketContainerListCopy(socketContainerListCopy);

    m_SocketContainerList.empty();

    for (auto iter = socketContainerListCopy.begin(); iter != socketContainerListCopy.end(); ++iter)
    {
        if (0 != (iter->m_ContainerTypeFlags & SOCKET_CONTAINER_TYPE_LISTEN)
            && iter->m_Backlog == 0)
        {
            // listening sockets with nothing to wait for should not be added to the nn::socket::FdSetSet
            continue;
        }
        else if (iter->m_Role == m_Role)
        {
            m_SocketContainerList.push_back(*iter);
        };
    };
}

ssize_t SelectUnitNetworkCommon::ReceiveBytes(const SocketContainer& container, uint8_t* pBufferPointer, size_t count, nn::socket::MsgFlag flags)
{
    ssize_t rc = -1;
    size_t nread = 0;

    while (nread < count)
    {
        rc = nn::socket::Recv(container.m_Socket, pBufferPointer + nread, count - nread, flags);
        nn::socket::Errno errorNumber = nn::socket::GetLastError();
        if ((-1 == rc) && (nn::socket::Errno::EAgain == errorNumber || nn::socket::Errno::EWouldBlock == errorNumber))
        {
            continue;
        }
        else if (-1 == rc)
        {
            goto bail;
        }
        else if (0 == rc)
        {
            goto bail;
        }
        else
        {
            nread += rc;
        };
    };

    rc = nread;

bail:
    return rc;
}

/**
 * @brief Sends exactly @p count bytes on the container's socket.
 *
 * Keeps calling Send until the whole buffer has been written. A Send that fails
 * with EAgain or EWouldBlock is retried immediately.
 *
 * @param container       container whose m_Socket is written to
 * @param pBufferPointer  source buffer, at least @p count bytes
 * @param count           number of bytes to send
 * @param flags           flags handed to every Send call
 * @return @p count when everything was sent; 0 as soon as a Send returns 0
 *         (bytes sent before that are not reported); -1 on any other Send error
 */
ssize_t SelectUnitNetworkCommon::SendBytes(SocketContainer& container, const uint8_t* pBufferPointer, size_t count, nn::socket::MsgFlag flags)
{
    size_t totalSent = 0;

    while (totalSent < count)
    {
        const ssize_t chunk = nn::socket::Send(container.m_Socket, pBufferPointer + totalSent, count - totalSent, flags);
        const nn::socket::Errno lastError = nn::socket::GetLastError();

        if (-1 == chunk)
        {
            const bool wouldBlock = (nn::socket::Errno::EAgain == lastError)
                                    || (nn::socket::Errno::EWouldBlock == lastError);
            if (wouldBlock)
            {
                // socket cannot take more right now; try again
                continue;
            }
            return chunk;
        }

        if (0 == chunk)
        {
            // Send returned 0: report it to the caller as-is
            return chunk;
        }

        totalSent += chunk;
    }

    return static_cast<ssize_t>(totalSent);
}

/**
 * @brief Dispatches the read, write and exception sets returned by select, then applies queued list changes.
 *
 * The three sets are processed in that order; processing stops right after the
 * first pass that leaves the validator in a failed state. When all three passes
 * succeed, every container in the shutdown queue that is still in
 * m_SocketContainerList is shut down, and every container in the waiting queue
 * is appended to m_SocketContainerList. Both queues are cleared afterwards.
 *
 * @param pReadFds       read set from select, may be NULL
 * @param pWriteFds      write set from select, may be NULL
 * @param pExceptionFds  exception set from select, may be NULL
 * @return sum of the values returned by the DispatchEventsInternal() passes that ran
 */
int SelectUnitNetworkCommon::DispatchEvents(nn::socket::FdSet* pReadFds, nn::socket::FdSet* pWriteFds, nn::socket::FdSet* pExceptionFds)
{
    //UNIT_TEST_TRACE("");

    int handledTotal = DispatchEventsInternal(pReadFds, SOCKET_CONTAINER_TYPE_READ);
    if (m_pValidator->DidFail() == true)
    {
        return handledTotal;
    }

    handledTotal += DispatchEventsInternal(pWriteFds, SOCKET_CONTAINER_TYPE_WRITE);
    if (m_pValidator->DidFail() == true)
    {
        return handledTotal;
    }

    handledTotal += DispatchEventsInternal(pExceptionFds, SOCKET_CONTAINER_TYPE_EXCEPTION);
    if (m_pValidator->DidFail() == true)
    {
        return handledTotal;
    }

    // shut down whatever the passes above queued, provided it is still listed
    for (auto& queued : m_ShutdownQueue)
    {
        auto listPosition = std::find(m_SocketContainerList.begin(), m_SocketContainerList.end(), queued);
        if (m_SocketContainerList.end() != listPosition)
        {
            ShutdownSocket(listPosition);
        }
    }
    m_ShutdownQueue.clear();

    // containers queued while handlers ran become part of the list now
    for (auto& pending : m_WaitingQueue)
    {
        UNIT_TEST_TRACE("adding socket: %d to socket queue", pending.m_Socket);
        m_SocketContainerList.push_back(pending);
    }
    m_WaitingQueue.clear();

    return handledTotal;
}

/**
 * @brief Runs the matching handler for every listed container whose socket is set in @p pFileDescriptorSet.
 *
 * For each ready container the first matching rule wins:
 * listen + read event, embryonic + read event, read + read event,
 * write + write event, exception + exception event. The walk stops as soon as
 * the validator reports a failure. After a complete walk, every container whose
 * m_CurrentState is at or beyond SOCKET_CONTAINER_TYPE_SHUTDOWN is copied to the
 * shutdown queue.
 *
 * @param pFileDescriptorSet  fd set filled in by select; NULL means nothing to do
 * @param events              kind of event the set represents
 * @return sum of the values returned by the handlers that were called
 */
int SelectUnitNetworkCommon::DispatchEventsInternal(nn::socket::FdSet* pFileDescriptorSet, SocketContainerTypeFlags events)
{
    //UNIT_TEST_TRACE("");

    int handledCount = 0;

    if (NULL == pFileDescriptorSet)
    {
        return handledCount;
    }

    // which kind of select result this pass is about
    const bool isReadPass = 0 != (events & SOCKET_CONTAINER_TYPE_READ);
    const bool isWritePass = 0 != (events & SOCKET_CONTAINER_TYPE_WRITE);
    const bool isExceptionPass = 0 != (events & SOCKET_CONTAINER_TYPE_EXCEPTION);

    for (auto& entry : m_SocketContainerList)
    {
        if (!nn::socket::FdSetIsSet(entry.m_Socket, pFileDescriptorSet))
        {
            continue;
        }

        if (isReadPass && 0 != (entry.m_ContainerTypeFlags & SOCKET_CONTAINER_TYPE_LISTEN))
        {
            // listening socket became readable: a connection is pending
            handledCount += OnNewConnectionEvent(entry);
        }
        else if (isReadPass && 0 != (entry.m_ContainerTypeFlags & SOCKET_CONTAINER_TYPE_EMBRYONIC))
        {
            // embryonic socket became readable
            handledCount += OnEmbryonicReadEvent(entry);
        }
        else if (isReadPass && 0 != (entry.m_ContainerTypeFlags & SOCKET_CONTAINER_TYPE_READ))
        {
            // read-type socket became readable
            handledCount += OnReadEvent(entry);
        }
        else if (isWritePass && 0 != (entry.m_ContainerTypeFlags & SOCKET_CONTAINER_TYPE_WRITE))
        {
            // write-type socket became writable
            handledCount += OnWriteEvent(entry);
        }
        else if (isExceptionPass && 0 != (entry.m_ContainerTypeFlags & SOCKET_CONTAINER_TYPE_EXCEPTION))
        {
            // exception-type socket reported an exceptional condition
            handledCount += OnExceptionEvent(entry);
        }

        // a failed validation ends the pass; the shutdown scan below is skipped
        if (m_pValidator->DidFail() == true)
        {
            return handledCount;
        }
    }

    // remember every container that has reached the shutdown state (or beyond)
    for (auto& entry : m_SocketContainerList)
    {
        if (entry.m_CurrentState >= SOCKET_CONTAINER_TYPE_SHUTDOWN)
        {
            m_ShutdownQueue.push_back(entry);
        }
    }

    return handledCount;
}

/**
 * @brief Fills an fd set with the sockets of all listed containers that carry any of the given type flags.
 *
 * Only containers whose role is SOCKET_CONTAINER_ACTOR_EMBRYONIC or equal to
 * this instance's role are considered. The set is zeroed first. m_CommonAccessLock
 * is held while the list is walked.
 *
 * @param socketContainerTypeFlags  type flags to match (any bit in common counts)
 * @param pFdSetInOut               in: set to fill (dereferenced, so must not be NULL);
 *                                  out: set to NULL when no container matched
 * @return number of containers added to the set
 */
unsigned int SelectUnitNetworkCommon::GetSocketContainerTypeCountAndSet(SocketContainerTypeFlags socketContainerTypeFlags, nn::socket::FdSet* & pFdSetInOut) const
{
    unsigned int matchCount = 0;

    nn::socket::FdSetZero(pFdSetInOut);

    nn::os::LockMutex(&m_CommonAccessLock);
    for (const auto& entry : m_SocketContainerList)
    {
        const bool roleMatches = (SOCKET_CONTAINER_ACTOR_EMBRYONIC == entry.m_Role)
                                 || (entry.m_Role == m_Role);
        if (!roleMatches)
        {
            continue;
        }

        if (0 != (entry.m_ContainerTypeFlags & socketContainerTypeFlags))
        {
            ++matchCount;
            nn::socket::FdSetSet(entry.m_Socket, pFdSetInOut);
        }
    }
    nn::os::UnlockMutex(&m_CommonAccessLock);

    // an empty set is reported to the caller as a NULL pointer
    if (0 == matchCount)
    {
        pFdSetInOut = NULL;
    }

    return matchCount;
}

/**
 * @brief Finds the largest socket descriptor in m_SocketContainerList.
 *
 * @return the largest m_Socket value in the list, or -1 when the list is empty
 */
int SelectUnitNetworkCommon::GetMaxFd()
{
    int highest = -1;

    for (const auto& entry : m_SocketContainerList)
    {
        const int descriptor = entry.m_Socket;
        if (highest < descriptor)
        {
            highest = descriptor;
        }
    }

    return highest;
}

bool SelectUnitNetworkCommon::ShutdownSocket(std::list<SocketContainer>::iterator & iteratorInOut)
{
    bool shouldContinue = false;
    int rc = -1;

    if (0 != (iteratorInOut->m_ContainerTypeFlags & SOCKET_CONTAINER_TYPE_LISTEN))
    {
        UNIT_TEST_TRACE("Unable to shutdown listener socket %d, calling listen with zero.", iteratorInOut->m_Socket);
        rc = nn::socket::Listen(iteratorInOut->m_Socket, 0);
        if (-1 == rc)
        {
            UNIT_TEST_TRACE("%s shutdown failed on socket %d, errno: %d", SocketContainerActorAsString(m_Role),
                            iteratorInOut->m_Socket,
                            nn::socket::GetLastError());
            NN_ASSERT(false);
            goto bail;
        };

    }
    else
    {
        rc = nn::socket::Shutdown(iteratorInOut->m_Socket, nn::socket::ShutdownMethod::Shut_RdWr);
        if (-1 == rc)
        {
            SVALIDATE_FAIL(m_pValidator, true, "%s shutdown failed on socket %d, errno: %d", SocketContainerActorAsString(m_Role), iteratorInOut->m_Socket, nn::socket::GetLastError());
            goto bail;
        };
        UNIT_TEST_TRACE("%s successfully shutdown socket %d", SocketContainerActorAsString(m_Role), iteratorInOut->m_Socket);
    };

    m_SocketContainerList.remove(*iteratorInOut++);
    if (iteratorInOut == m_SocketContainerList.end())
    {
        goto bail;
    };

    shouldContinue = true;

bail:
    return shouldContinue;
};


/**
 * @brief Queues a new container and registers it with the shared unit data.
 *
 * The container is not put into m_SocketContainerList directly; it is appended
 * to the waiting queue, which DispatchEvents() drains into the list after a
 * dispatch pass has finished.
 *
 * @param newContainer  container to queue; a copy is stored in the waiting queue
 */
void SelectUnitNetworkCommon::AddSocketContainerToWaitingQueue(SocketContainer& newContainer)
{
    UNIT_TEST_TRACE("%s adding socket: %d to waiting queue", SocketContainerActorAsString(m_Role), newContainer.m_Socket);
    m_WaitingQueue.push_back(newContainer);
    // also make the container known to the shared unit data
    m_pSelectUnitData->AddSocketContainerToList(newContainer);
};


/**
 * @brief Serializes a container and sends it over the container's own socket.
 *
 * @param container  container to send; taken by value, so the state flags set
 *                   here on failure only affect the local copy
 * @param eventName  label used in trace and validation messages
 * @param flags      flags handed to SendBytes()
 * @return number of bytes sent on success; 0 or -1 when the connection is gone
 *         (SendBytes() returned 0, or -1 with EPipe / EConnAborted / EConnReset);
 *         -1 with a validation failure when serialization or sending failed otherwise
 */
ssize_t SelectUnitNetworkCommon::SendContainer(SocketContainer container, const char* eventName, nn::socket::MsgFlag flags)
{
    //UNIT_TEST_TRACE("");
    uint8_t buf[1024] = { '\0' };
    uint8_t* pBuffer = buf;

    ssize_t rc = container.ToNetworkBuffer(pBuffer, sizeof(buf));
    if (SVALIDATE_FAIL(m_pValidator, (-1 == rc), "%s: toNetworkBuffer failed", eventName))
    {
        container.m_CurrentState = static_cast<SocketContainerTypeFlags>(container.m_CurrentState | SOCKET_CONTAINER_TYPE_ERROR);
        return rc;
    }

    rc = SendBytes(container, buf, container.SizeOf(), flags);
    const nn::socket::Errno errorNumber = nn::socket::GetLastError();

    // errno values that mean the connection went away (the last two: siglo for windows)
    const bool disconnectErrno = (nn::socket::Errno::EPipe == errorNumber)
                                 || (nn::socket::Errno::EConnAborted == errorNumber)
                                 || (nn::socket::Errno::EConnReset == errorNumber);

    if (0 == rc || (-1 == rc && disconnectErrno))
    {
        container.m_CurrentState = static_cast<SocketContainerTypeFlags>(container.m_CurrentState | SOCKET_CONTAINER_TYPE_SHUTDOWN);
        return rc;
    }

    if (-1 == rc)
    {
        SVALIDATE_FAIL(m_pValidator, true, "%s: SendBytes failed, socket: %d, rc: %d, errno: %d", eventName, container.m_Socket, rc, errorNumber);
        container.m_CurrentState = static_cast<SocketContainerTypeFlags>(container.m_CurrentState | SOCKET_CONTAINER_TYPE_ERROR);
        return rc;
    }

    UNIT_TEST_TRACE("%s: sent %d bytes on socket %d", eventName, rc, container.m_Socket);
    return rc;
}


/**
 * @brief Receives one serialized container from @p containerFrom's socket and decodes it into @p containerTo.
 *
 * When the connection is gone, containerFrom.m_CurrentState is replaced with
 * SOCKET_CONTAINER_TYPE_SHUTDOWN; on any other receive error it is replaced with
 * SOCKET_CONTAINER_TYPE_ERROR and a validation failure is recorded.
 *
 * @param containerTo    receives the decoded container
 * @param containerFrom  container whose socket is read; its state is updated on failure
 * @param eventName      label used in trace and validation messages
 * @param flags          NOTE(review): not used; the receive is always done with
 *                       Msg_None -- confirm whether that is intentional
 * @return number of bytes received on success; 0 or -1 when the connection is gone
 *         or the receive failed; -1 when decoding failed
 */
ssize_t SelectUnitNetworkCommon::ReceiveContainer(SocketContainer& containerTo, SocketContainer& containerFrom, const char* eventName, nn::socket::MsgFlag flags)
{
    //UNIT_TEST_TRACE("");
    uint8_t buf[1024] = { '\0' };
    uint8_t* pBuffer = buf;

    ssize_t rc = ReceiveBytes(containerFrom, buf, SocketContainer::SizeOf(), nn::socket::MsgFlag::Msg_None);
    const nn::socket::Errno errorNumber = nn::socket::GetLastError();

    // errno values that mean the connection went away (the last two: siglo for windows)
    const bool disconnectErrno = (nn::socket::Errno::EPipe == errorNumber)
                                 || (nn::socket::Errno::EConnAborted == errorNumber)
                                 || (nn::socket::Errno::EConnReset == errorNumber);

    if (0 == rc || (-1 == rc && disconnectErrno))
    {
        UNIT_TEST_TRACE("%s: other side closed the socket %d, error: %d", eventName, containerFrom.m_Socket, errorNumber);
        containerFrom.m_CurrentState = SOCKET_CONTAINER_TYPE_SHUTDOWN;
        return rc;
    }

    if (-1 == rc)
    {
        SVALIDATE_FAIL(m_pValidator, true, "%s: OOB recv failed: %d, socket: %d, errno: %d", eventName, rc, containerFrom.m_Socket, errorNumber);
        containerFrom.m_CurrentState = SOCKET_CONTAINER_TYPE_ERROR;
        return rc;
    }

    UNIT_TEST_TRACE("%s %d bytes, socket: %d", eventName, rc, containerFrom.m_Socket, containerFrom.m_ContainerTypeFlags);

    // turn the received bytes back into a container
    if (SVALIDATE_FAIL(m_pValidator, (0 > SocketContainer::FromNetworkBuffer(containerTo, pBuffer, rc)), "fromNetworkBuffer failed"))
    {
        rc = -1;
    }

    return rc;
}

}}; // NATF::API
