﻿/*--------------------------------------------------------------------------------*
  Copyright (C)Nintendo All rights reserved.

  These coded instructions, statements, and computer programs contain proprietary
  information of Nintendo and/or its licensed developers and are protected by
  national and international copyright laws. They may not be disclosed to third
  parties or copied or duplicated in any form, in whole or in part, without the
  prior written consent of Nintendo.

  The content herein is highly confidential and should be handled accordingly.
 *--------------------------------------------------------------------------------*/

#include <nn/nn_Assert.h>
#include <nn/util/util_FormatString.h>
#include <nn/nn_SdkLog.h>

#include <cstring>
#include <algorithm>

#include "testNet_ApiCommon.h"
#include "Complex/testNet_UnitCommon.h"
#include "Complex/testNet_SelectUnitData.h"

namespace NATF {
namespace API {

/**
 * SelectUnitData
 *
 * Acquires every socket a select() test needs and records it in m_SocketContainerList:
 *  - one server socket tagged SOCKET_CONTAINER_TYPE_LISTEN, whose backlog is the total
 *    of numberRead + numberWrite + numberException, and
 *  - numberRead / numberWrite / numberException client sockets tagged for the read,
 *    write and exception sets respectively.
 *
 * Acquisition stops at the first group that does not yield the expected number of sockets;
 * the failure is reported through the validator via SVALIDATE_FAIL. Sockets acquired so far
 * remain in m_SocketContainerList, which the destructor closes.
 *
 * @param validator          expected to be non-NULL (asserted); a reference is added here
 *                           and released in the destructor.
 * @param selectTimeoutInMs  timeout later exposed by GetSelectTimeout(); 0 means "no timeout".
 */
SelectUnitData::SelectUnitData(AddressFamilyFlags addressFamily,
                               SocketTypeFlags socketTypeFlags,
                               ProtocolFlags protocolFlags,
                               SocketControlFlags socketControlFlags,
                               unsigned int numberRead,
                               unsigned int numberWrite,
                               unsigned int numberException,
                               SimpleValidator* validator,
                               uint64_t selectTimeoutInMs) :
    LockedReferenceCountObjectImpl(__FUNCTION__),
    m_AddressFamilyFlags(addressFamily),
    m_SocketTypeFlags(socketTypeFlags),
    m_ProtocolFlags(protocolFlags),
    m_ControlFlags(socketControlFlags),
    m_Port(0),
    m_Validator(validator),
    m_SelectTimeout(selectTimeoutInMs)
{
    //UNIT_TEST_TRACE("");
    NN_ASSERT(NULL != m_Validator);

    nn::os::InitializeMutex(&m_DataAccessLock, true, 0);

    if ( NULL != m_Validator )
    {
        m_Validator->addReference();
    }

    // AcquireSockets() returns how many sockets it appended to the list (it never returns
    // -1), so success means "got exactly as many as requested". The server side creates a
    // single listen socket and uses the requested count only as its backlog.
    const int totalSockets          = static_cast<int>(numberRead + numberWrite + numberException);
    const int expectedServerSockets = (totalSockets > 0) ? 1 : 0;

    /** SERVER */
    if (SVALIDATE_FAIL(m_Validator,
                       (expectedServerSockets != AcquireSockets(m_AddressFamilyFlags,
                                                                m_SocketTypeFlags,
                                                                m_ProtocolFlags,
                                                                SOCKET_CONTAINER_ACTOR_SERVER,
                                                                SOCKET_CONTAINER_TYPE_LISTEN,
                                                                m_ControlFlags,
                                                                totalSockets,
                                                                m_SocketContainerList)),
                       "unable to acquire server read-set sockets"))
    {
        return;
    }

    /** CLIENT */
    if (SVALIDATE_FAIL(m_Validator,
                       (static_cast<int>(numberRead) != AcquireSockets(m_AddressFamilyFlags,
                                                                       m_SocketTypeFlags,
                                                                       m_ProtocolFlags,
                                                                       SOCKET_CONTAINER_ACTOR_CLIENT,
                                                                       SOCKET_CONTAINER_TYPE_READ,
                                                                       m_ControlFlags,
                                                                       numberRead,
                                                                       m_SocketContainerList)),
                       "unable to acquire client read-set sockets"))
    {
        return;
    }

    if (SVALIDATE_FAIL(m_Validator,
                       (static_cast<int>(numberWrite) != AcquireSockets(m_AddressFamilyFlags,
                                                                        m_SocketTypeFlags,
                                                                        m_ProtocolFlags,
                                                                        SOCKET_CONTAINER_ACTOR_CLIENT,
                                                                        SOCKET_CONTAINER_TYPE_WRITE,
                                                                        m_ControlFlags,
                                                                        numberWrite,
                                                                        m_SocketContainerList)),
                       "unable to acquire client write-set sockets"))
    {
        return;
    }

    // Use the stored control flags like the other groups do (value is identical).
    if (SVALIDATE_FAIL(m_Validator,
                       (static_cast<int>(numberException) != AcquireSockets(m_AddressFamilyFlags,
                                                                            m_SocketTypeFlags,
                                                                            m_ProtocolFlags,
                                                                            SOCKET_CONTAINER_ACTOR_CLIENT,
                                                                            SOCKET_CONTAINER_TYPE_EXCEPTION,
                                                                            m_ControlFlags,
                                                                            numberException,
                                                                            m_SocketContainerList)),
                       "unable to acquire client error-set sockets"))
    {
        return;
    }
}

/**
 * Releases the validator reference taken by the constructor, then closes every socket
 * still tracked in m_SocketContainerList. A failed close is traced and asserted.
 * Finally tears down the data-access mutex.
 */
SelectUnitData::~SelectUnitData()
{
    //UNIT_TEST_TRACE("");
    if (m_Validator != NULL)
    {
        m_Validator->releaseReference();
        m_Validator = NULL;
    }

    for (auto& container : m_SocketContainerList)
    {
        const bool closeFailed = (nn::socket::Close(container.m_Socket) == -1);
        if (!closeFailed)
        {
            UNIT_TEST_TRACE("closed socket %d successfully", container.m_Socket);
            continue;
        }

        UNIT_TEST_TRACE("unable to close socket %d, errno: %d", container.m_Socket, nn::socket::GetLastError());
        NN_ASSERT(false);
    }

    nn::os::FinalizeMutex(&m_DataAccessLock);
}

/**
 * Placeholder for explicit resource cleanup; currently performs no work.
 * Sockets and the validator reference are released by the destructor instead.
 */
void SelectUnitData::Cleanup()
{
    // TODO: implement explicit cleanup.
    //UNIT_TEST_TRACE("");
}

/**
 * Creates sockets of the requested family/type/protocol and appends a SocketContainer
 * describing each one to @a socketContainerList.
 *
 * - SOCKET_CONTAINER_ACTOR_SERVER: a single socket is created (if @a number > 0) and its
 *   m_Backlog is set to @a number.
 * - SOCKET_CONTAINER_ACTOR_CLIENT: up to @a number sockets are created.
 *
 * A socket that cannot be tracked (unhandled control flags or unknown role) is closed
 * before returning so it does not leak.
 *
 * @return the number of sockets appended to @a socketContainerList. This is never -1; on
 *         failure it is simply smaller than requested, and the sockets appended so far
 *         remain in the list.
 */
int SelectUnitData::AcquireSockets(AddressFamilyFlags addressFamilyFlags,
                                   SocketTypeFlags socketTypeFlags,
                                   ProtocolFlags protocolFlags,
                                   SocketContainerActor role,
                                   SocketContainerTypeFlags typeFlags,
                                   SocketControlFlags controlFlags,
                                   int number,
                                   std::list<SocketContainer> & socketContainerList)
{
    //UNIT_TEST_TRACE("");

    SocketContainer     socketContainer;
    nn::socket::Family  af = nn::socket::Family::Af_Unspec;
    nn::socket::Type    st = nn::socket::Type::Sock_Default;
    uint32_t            prot = 0;
    int                 numAcquired = 0;

    switch (addressFamilyFlags)
    {
    case ADDRESS_FAMILY_FLAGS_AF_UNSPEC:
        af = nn::socket::Family::Af_Unspec;
        break;
    case ADDRESS_FAMILY_FLAGS_AF_INET:
        af = nn::socket::Family::Af_Inet;
        break;
    default:
        NN_ASSERT(false);
    }

    switch (socketTypeFlags)
    {
    case SOCK_TYPE_FLAGS_SOCK_STREAM:
        st = nn::socket::Type::Sock_Stream;
        break;
    case SOCK_TYPE_FLAGS_SOCK_DGRAM:
        st = nn::socket::Type::Sock_Dgram;
        break;
    default:
        NN_ASSERT(false);
    }

    switch (protocolFlags)
    {
    case PROTOCOL_FLAGS_DEFAULT:
        prot = 0;
        break;
    default:
        NN_ASSERT(false);
    }

    socketContainer.m_AddressFamilyFlags = addressFamilyFlags;
    socketContainer.m_SocketTypeFlags = socketTypeFlags;
    socketContainer.m_ProtocolFlags = protocolFlags;
    socketContainer.m_SocketControlFlags = controlFlags;
    socketContainer.m_Role = role;
    socketContainer.m_ContainerTypeFlags = typeFlags;

    for (int idx = 0; idx < number; ++idx)
    {
        socketContainer.m_Socket = nn::socket::Socket(af, st, static_cast<nn::socket::Protocol>(prot));
        if (-1 == socketContainer.m_Socket)
        {
            UNIT_TEST_TRACE("%s unable to create socket, errno: %d", SocketContainerActorAsString(socketContainer.m_Role), nn::socket::GetLastError() );
            NN_ASSERT(false);
            break;
        }
        UNIT_TEST_TRACE("acquired %s socket %d", SocketContainerActorAsString(socketContainer.m_Role), socketContainer.m_Socket);

        if (SOCKET_CONTAINER_ACTOR_SERVER == role)
        {
            // The server only needs one listen socket; the requested count becomes its backlog.
            socketContainer.m_Backlog = number;
            socketContainerList.push_back(socketContainer);
            ++numAcquired;
            break;
        }
        else if (SOCKET_CONTAINER_ACTOR_CLIENT == role)
        {
            switch (socketContainer.m_SocketControlFlags)
            {
                // TODO: make this more flags than value
            case SOCKET_CONTROL_FLAGS_DEFAULT:
                break;
            case SOCKET_CONTROL_FLAGS_SOCK_NONBLOCKING:
                //nn::socket::Fcntl(socketContainer.m_Socket, nn::socket::FcntlCommand::F_SetFl, nn::socket::FcntlFlag::O_NonBlock);
                break;
            default:
                UNIT_TEST_TRACE("%s unhandled default case hit for socket flags", SocketContainerActorAsString(socketContainer.m_Role));
                // This socket never reaches the list (and therefore the destructor), so close it here.
                nn::socket::Close(socketContainer.m_Socket);
                NN_ASSERT(false);
                return numAcquired;
            }
            socketContainerList.push_back(socketContainer);
            ++numAcquired;
        }
        else
        {
            // Unknown role: the socket would otherwise be created and silently leaked.
            UNIT_TEST_TRACE("%s unhandled socket container role, closing socket %d", SocketContainerActorAsString(socketContainer.m_Role), socketContainer.m_Socket);
            nn::socket::Close(socketContainer.m_Socket);
            NN_ASSERT(false);
            break;
        }
    }

    return numAcquired;
}

/**
 * Copies the tracked socket list into @a socketContainerListOut while holding
 * m_DataAccessLock, replacing whatever the output list held before.
 */
void SelectUnitData::GetSocketContainerListCopy(std::list<SocketContainer>& socketContainerListOut) const
{
    //UNIT_TEST_TRACE("");
    nn::os::LockMutex(&m_DataAccessLock);
    socketContainerListOut = m_SocketContainerList;
    nn::os::UnlockMutex(&m_DataAccessLock);
}

/**
 * Appends a copy of @a socketContainerIn to the end of the tracked socket list,
 * under m_DataAccessLock.
 */
void SelectUnitData::AddSocketContainerToList(const SocketContainer& socketContainerIn)
{
    //UNIT_TEST_TRACE("");
    nn::os::LockMutex(&m_DataAccessLock);
    m_SocketContainerList.insert(m_SocketContainerList.end(), socketContainerIn);
    nn::os::UnlockMutex(&m_DataAccessLock);
}

/**
 * Removes every element of the tracked socket list that compares equal to
 * @a socketContainerIn, under m_DataAccessLock. The underlying socket is not closed.
 * Does nothing if no element matches.
 */
void SelectUnitData::RemoveSocketContainerFromList(const SocketContainer& socketContainerIn)
{
    //UNIT_TEST_TRACE("");
    nn::os::LockMutex(&m_DataAccessLock);
    {
        // std::list::remove walks the list exactly once and erases all matches, so there
        // is no hand-maintained iterator that could fail to advance while the lock is held.
        m_SocketContainerList.remove(socketContainerIn);
    }
    nn::os::UnlockMutex(&m_DataAccessLock);
}


/**
 * Returns the validator with an extra reference added, or NULL if none is set.
 * A non-NULL result is expected to be balanced by the caller with releaseReference().
 */
SimpleValidator* SelectUnitData::GetValidator()
{
    //UNIT_TEST_TRACE("");
    SimpleValidator* pSimpleValidator = NULL;
    nn::os::LockMutex(&m_DataAccessLock);
    {
        pSimpleValidator = m_Validator;
        // Same NULL guard the constructor uses; m_Validator is also cleared by the destructor.
        if (NULL != pSimpleValidator)
        {
            pSimpleValidator->addReference();
        }
    }
    nn::os::UnlockMutex(&m_DataAccessLock);
    return pSimpleValidator;
}

uint16_t SelectUnitData::GetPort() const
{
    //UNIT_TEST_TRACE("");
    uint16_t rc = 0;
    nn::os::LockMutex(&m_DataAccessLock);
    {
        rc = m_Port;
    }
    nn::os::UnlockMutex(&m_DataAccessLock);
    return rc;
}

/**
 * Stores @a port under m_DataAccessLock so GetPort() can return it.
 */
void SelectUnitData::SetPort(uint16_t port)
{
    //UNIT_TEST_TRACE("");
    nn::os::LockMutex(&m_DataAccessLock);
    m_Port = port;
    nn::os::UnlockMutex(&m_DataAccessLock);
}

/**
 * Fills the caller-provided TimeVal with the configured select() timeout.
 *
 * @param pSelectTimeoutInOut  in: pointer to caller storage (if NULL, nothing is done).
 *                             out: unchanged pointer with tv_sec/tv_usec filled in, or set
 *                             to NULL when the configured timeout is 0, so the pointer can
 *                             be handed straight to select() (NULL = no timeout).
 */
void SelectUnitData::GetSelectTimeout(nn::socket::TimeVal *& pSelectTimeoutInOut) const
{
    if (NULL == pSelectTimeoutInOut)
    {
        return;
    }

    uint64_t selectTimeout = 0;
    nn::os::LockMutex(&m_DataAccessLock);
    {
        selectTimeout = m_SelectTimeout;
    }
    nn::os::UnlockMutex(&m_DataAccessLock);

    if (0 == selectTimeout)
    {
        pSelectTimeoutInOut = NULL;
        return;
    }

    // Cast to the members' own types rather than through `long`, which is only 32 bits on
    // LLP64 / 32-bit targets and would truncate large millisecond timeouts.
    pSelectTimeoutInOut->tv_sec  = static_cast<decltype(pSelectTimeoutInOut->tv_sec)>(selectTimeout / 1000);
    pSelectTimeoutInOut->tv_usec = static_cast<decltype(pSelectTimeoutInOut->tv_usec)>((selectTimeout % 1000) * 1000);
}

}}; // namespace NATF::API
