﻿/*--------------------------------------------------------------------------------*
  Copyright (C)Nintendo All rights reserved.

  These coded instructions, statements, and computer programs contain proprietary
  information of Nintendo and/or its licensed developers and are protected by
  national and international copyright laws. They may not be disclosed to third
  parties or copied or duplicated in any form, in whole or in part, without the
  prior written consent of Nintendo.

  The content herein is highly confidential and should be handled accordingly.
 *--------------------------------------------------------------------------------*/

#include <nn/os.h>
#include <nn/nn_Log.h>
#include <nn/socket.h>
#include <nnt/nntest.h>

namespace NATF {
namespace API {

// Backing memory handed to nn::socket::Initialize() by testGetAndSetNonblockSimple().
// Declared with 4096-byte alignment and sized to nn::socket::DefaultSocketMemoryPoolSize.
NN_ALIGNAS(4096) uint8_t g_GetAndSetNonblockSimpleSocketMemoryPoolBuffer[nn::socket::DefaultSocketMemoryPoolSize];

void testGetAndSetNonblockSimple(uint32_t numberOfTimes=1)
{
    nn::Result result;
    int rc = -1;
    int sockfd = -1;
    int initialFlags = -1;
    int flags = -1;

    // retain result to finalize the socket layer below
    result = nn::socket::Initialize(g_GetAndSetNonblockSimpleSocketMemoryPoolBuffer,
                                    nn::socket::DefaultSocketMemoryPoolSize,
                                    nn::socket::MinSocketAllocatorSize,
                                    nn::socket::DefaultConcurrencyLimit);

    if( result.IsFailure() )
    {
        NN_LOG("Error: nn::socket::Initialize() failed. Err Desc: %d\n", result.GetDescription());
        ADD_FAILURE();
        goto bail;
    }
    // create the socket
    else if (-1 == (sockfd = nn::socket::Socket(nn::socket::Family::Af_Inet, nn::socket::Type::Sock_Stream, nn::socket::Protocol::IpProto_Tcp)))
    {
        NN_LOG("Unable to get socket, rc: %d, errno: %d\n", rc, sockfd, nn::socket::GetLastError());
        ADD_FAILURE();
        goto bail;
    };
    for (unsigned int idx=0; idx<numberOfTimes; ++idx)
    {
        // GET default socket flags and check that nn::socket::FcntlFlag::O_NonBlock is not set (default expected case)
        if (-1 == (initialFlags = nn::socket::Fcntl(sockfd, nn::socket::FcntlCommand::F_GetFl, 0)))
        {
            NN_LOG("Error: Fcntl(%d, nn::socket::FcntlCommand::F_GetFl, 0) returned -1, errno: %d\n", sockfd, nn::socket::GetLastError());
            ADD_FAILURE();
            goto bail;
        }
        else if ( 0 != ( static_cast<int>(nn::socket::FcntlFlag::O_NonBlock) & initialFlags))
        {
            NN_LOG("Error: initial flags (%x) contain nn::socket::FcntlFlag::O_NonBlock, but it shouldn't yet.\n", rc);
            ADD_FAILURE();
            goto bail;
        };

        flags = initialFlags | static_cast<int>(nn::socket::FcntlFlag::O_NonBlock);
        if ( -1 == (rc = nn::socket::Fcntl(sockfd, nn::socket::FcntlCommand::F_SetFl, flags)))
        {
            NN_LOG("Error: Fcntl(%d, nn::socket::FcntlCommand::F_SetFl, flags | nn::socket::FcntlFlag::O_NonBlock) returned -1\n", sockfd);
            ADD_FAILURE();
            goto bail;
        }
        else if (-1 == (rc = nn::socket::Fcntl(sockfd, nn::socket::FcntlCommand::F_GetFl, 0)))
        {
            NN_LOG("Error: Fcntl(%d, nn::socket::FcntlCommand::F_GetFl, 0) returned -1\n", sockfd);
            ADD_FAILURE();
            goto bail;
        }
        else if ( 0 == ( static_cast<int>(nn::socket::FcntlFlag::O_NonBlock) & rc))
        {
            NN_LOG("Error:  flags(%x) does not contain nn::socket::FcntlFlag::O_NonBlock, but it should.\n", rc);
            ADD_FAILURE();
            goto bail;
        };

        // unset nn::socket::FcntlFlag::O_NonBlock and then get the flags to ensure that we can set the socket back to blocking mode
        flags  = initialFlags;
        if ( -1 == (rc = nn::socket::Fcntl(sockfd, nn::socket::FcntlCommand::F_SetFl, flags)))
        {
            NN_LOG("Error: Fcntl(%d, nn::socket::FcntlCommand::F_SetFl, 0) returned -1\n", sockfd);
            ADD_FAILURE();
            goto bail;
        }
        else if (-1 == (rc = nn::socket::Fcntl(sockfd, nn::socket::FcntlCommand::F_GetFl, flags)))
        {
            NN_LOG("Error: Fcntl(%d, nn::socket::FcntlCommand::F_GetFl, 0) returned -1\n", sockfd);
            ADD_FAILURE();
            goto bail;
        }
        else if ( 0 != ( static_cast<int>(nn::socket::FcntlFlag::O_NonBlock) & rc))
        {
            NN_LOG("Error:  flags(%x) contains nn::socket::FcntlFlag::O_NonBlock, but should not, because we set the initial flags value.\n", rc);
            ADD_FAILURE();
            goto bail;
        };
    }

bail:
    if ( -1 != sockfd )
    {
        nn::socket::Close(sockfd);
        sockfd = -1;
    };

    if (result.IsSuccess())
    {
        result = nn::socket::Finalize();
        if( result.IsFailure() )
        {
            NN_LOG("Error: nn::socket::Finalize() failed. Err Desc: %d\n", result.GetDescription());
        };
    };
};

// Runs a single get/set/restore cycle of the non-blocking flag test.
TEST(GetAndSetNonblockSimple, RunOnce)
{
    const uint32_t iterationCount = 1;
    testGetAndSetNonblockSimple(iterationCount);
}

// Repeats the get/set/restore cycle many times on the same socket.
TEST(GetAndSetNonblockSimple, RunMany)
{
    // regression for SIGLONTD-4866
    // The largest uint8_t value (255), widened to the helper's uint32_t parameter.
    const uint32_t iterationCount = static_cast<uint8_t>(-1);
    testGetAndSetNonblockSimple(iterationCount);
}

}} // namespace NATF::API
