﻿/*--------------------------------------------------------------------------------*
  Copyright (C)Nintendo All rights reserved.

  These coded instructions, statements, and computer programs contain proprietary
  information of Nintendo and/or its licensed developers and are protected by
  national and international copyright laws. They may not be disclosed to third
  parties or copied or duplicated in any form, in whole or in part, without the
  prior written consent of Nintendo.

  The content herein is highly confidential and should be handled accordingly.
 *--------------------------------------------------------------------------------*/
#include <cstdint>
#include <cstdio>
#include <cstring>

#include <nnt/nntest.h>
#include <nn/nn_Log.h>
#include <nn/init.h>
#include <nn/os.h>
#include <nn/ssl.h>
#include <nn/socket.h>

#include <Common/testCommonUtil.h>
#include <Common/testInfraInfo.h>

// ------------------------------------------------------------------------------------------------
// Parameters
// ------------------------------------------------------------------------------------------------
namespace
{
// Timeout, in milliseconds, handed to every nn::ssl::Connection::Poll() call
// made by the tests in this file.
constexpr uint32_t MsecPollTimeout = 5000;

// Delay, in milliseconds, between sending the HTTP request and polling for the
// server's reply.
constexpr uint32_t MsecWaitServerReply = 1000;

// Network / TCP-socket helper shared by every test below.
SslTestCommonUtil g_CommonUtil;

// Memory pool handed to nn::socket::Initialize() in InitTest; aligned to 4096 bytes.
NN_ALIGNAS(4096) uint8_t g_SocketMemoryPoolBuffer[nn::socket::DefaultSocketMemoryPoolSize];
} // Un-named namespace

// ------------------------------------------------------------------------------------------------
// Utils
// ------------------------------------------------------------------------------------------------
// Bounded printf-style formatting into a caller-supplied buffer.
// Non-Windows builds use the standard snprintf; Windows builds use _snprintf.
// NOTE: unlike C99 snprintf, _snprintf returns -1 and does not NUL-terminate
// when the output is truncated, so callers should check the return value.
#if !defined(NN_BUILD_CONFIG_OS_WIN)
#define MY_SNPRINTF snprintf
#else
#define MY_SNPRINTF _snprintf
#endif

namespace
{
    // Dumps the Read / Write / Except flags of a poll event mask to the log as
    // YES/NO lines framed by two separator lines. A flag is reported as YES only
    // when all of its bits are present in `events`.
    void PrintPollEventState(nn::ssl::Connection::PollEvent events)
    {
        NN_LOG(" PollEvent ------------------\n");
        NN_LOG(" Read:   %s\n",
            ((events & nn::ssl::Connection::PollEvent::PollEvent_Read) !=
                nn::ssl::Connection::PollEvent::PollEvent_Read) ? "NO" : "YES");
        NN_LOG(" Write:  %s\n",
            ((events & nn::ssl::Connection::PollEvent::PollEvent_Write) !=
                nn::ssl::Connection::PollEvent::PollEvent_Write) ? "NO" : "YES");
        NN_LOG(" Except: %s\n",
            ((events & nn::ssl::Connection::PollEvent::PollEvent_Except) !=
                nn::ssl::Connection::PollEvent::PollEvent_Except) ? "NO" : "YES");
        NN_LOG(" ----------------------------\n");
    }
} // Un-named namespace

// ------------------------------------------------------------------------------------------------
// Tests
// ------------------------------------------------------------------------------------------------
// Brings up, in order: the network, the socket library (backed by
// g_SocketMemoryPoolBuffer) and the SSL library. FinalizeTest tears these down
// in the reverse order.
TEST(InitTest, Success)
{
    const auto networkResult = g_CommonUtil.SetupNetwork();
    ASSERT_TRUE(networkResult.IsSuccess());

    const auto socketResult = nn::socket::Initialize(
        g_SocketMemoryPoolBuffer,
        nn::socket::DefaultSocketMemoryPoolSize,
        nn::socket::MinSocketAllocatorSize,
        nn::socket::DefaultConcurrencyLimit);
    ASSERT_TRUE(socketResult.IsSuccess());

    const auto sslResult = nn::ssl::Initialize();
    ASSERT_TRUE(sslResult.IsSuccess());
}

TEST(Poll, Success)
{
    bool isSocketImported = false;
    nn::Result result;
    int socketFd = g_CommonUtil.CreateTcpSocket(
        true,
        ServerPort_Normal,
#ifndef NO_RESOLVER
        ServerName,
        0
#else
        nullptr,
        ServerIpAddress
#endif
    );
    EXPECT_TRUE(socketFd >= 0);

    nn::ssl::Context*    pSslContext = new nn::ssl::Context();;
    nn::ssl::Connection* pSslConnection = new nn::ssl::Connection();;
    ASSERT_TRUE(pSslContext != nullptr);
    ASSERT_TRUE(pSslConnection != nullptr);

    result = pSslContext->Create(nn::ssl::Context::SslVersion_Auto);
    ASSERT_TRUE(result.IsSuccess());
    result = pSslConnection->Create(pSslContext);
    ASSERT_TRUE(result.IsSuccess());

    result = pSslConnection->SetSocketDescriptor(socketFd);
    ASSERT_TRUE(result.IsSuccess());
    isSocketImported = true;
    result = pSslConnection->SetHostName(ServerName, strlen(ServerName));
    ASSERT_TRUE(result.IsSuccess());

    result = pSslConnection->SetOption(nn::ssl::Connection::OptionType_SkipDefaultVerify, true);
    ASSERT_TRUE(result.IsSuccess());

    result = pSslConnection->SetVerifyOption(nn::ssl::Connection::VerifyOption::VerifyOption_None);
    ASSERT_TRUE(result.IsSuccess());

    result = pSslConnection->DoHandshake();
    ASSERT_TRUE(result.IsSuccess());
    NN_LOG(" SSL Handshake completed.\n");

    // Poll should return immediately
    {
        nn::ssl::Connection::PollEvent pollEvent = nn::ssl::Connection::PollEvent::PollEvent_None;
        nn::ssl::Connection::PollEvent pollOutEvent = nn::ssl::Connection::PollEvent::PollEvent_None;
        pollEvent |= nn::ssl::Connection::PollEvent::PollEvent_Write;

        NN_LOG("Calling poll w/to %d for write.\n", MsecPollTimeout);
        nn::os::Tick startTick = nn::os::GetSystemTick();
        result = pSslConnection->Poll(&pollOutEvent, &pollEvent, MsecPollTimeout);
        uint64_t elapsed = nn::os::ConvertToTimeSpan(nn::os::GetSystemTick() - startTick).GetSeconds();
        EXPECT_TRUE(result.IsSuccess());
        EXPECT_TRUE(elapsed == 0);
        EXPECT_TRUE(pollOutEvent == nn::ssl::Connection::PollEvent::PollEvent_Write);
        NN_LOG("Poll returned after %d sec.\n", elapsed);
        PrintPollEventState(pollOutEvent);
    }

    // Poll should return after MsecPollTimeout
    {
        nn::ssl::Connection::PollEvent pollEvent = nn::ssl::Connection::PollEvent::PollEvent_None;
        nn::ssl::Connection::PollEvent pollOutEvent = nn::ssl::Connection::PollEvent::PollEvent_None;
        pollEvent |= nn::ssl::Connection::PollEvent::PollEvent_Read;

        NN_LOG("Calling poll w/to %d for read.\n", MsecPollTimeout);
        nn::os::Tick startTick = nn::os::GetSystemTick();
        result = pSslConnection->Poll(&pollOutEvent, &pollEvent, MsecPollTimeout);
        uint64_t elapsed = nn::os::ConvertToTimeSpan(nn::os::GetSystemTick() - startTick).GetMilliSeconds();
        EXPECT_TRUE(nn::ssl::ResultIoTimeout::Includes(result));
        EXPECT_TRUE(elapsed >= MsecPollTimeout - 5);
        EXPECT_TRUE(pollOutEvent == nn::ssl::Connection::PollEvent::PollEvent_None);
        NN_LOG("Poll returned after %d millisec.\n", elapsed);
        PrintPollEventState(pollOutEvent);
    }

    char httpReqBuff[64] = {0};
    MY_SNPRINTF(httpReqBuff, sizeof(httpReqBuff), "GET / HTTP/1.0\r\nHost: %s\r\n\r\n", ServerName);
    uint32_t httpReqBuffLen = static_cast<uint32_t>(strlen(httpReqBuff));
    int sentBytes = pSslConnection->Write(httpReqBuff, httpReqBuffLen);
    ASSERT_TRUE(sentBytes > 0);
    NN_LOG(" Sent HTTP request over SSL (%d bytes).\n", sentBytes);

    // Wait data from the server
    nn::os::SleepThread(nn::TimeSpan::FromMilliSeconds(MsecWaitServerReply));

    // Poll should return immediately
    {
        nn::ssl::Connection::PollEvent pollEvent = nn::ssl::Connection::PollEvent::PollEvent_None;
        nn::ssl::Connection::PollEvent pollOutEvent = nn::ssl::Connection::PollEvent::PollEvent_None;
        pollEvent |= nn::ssl::Connection::PollEvent::PollEvent_Read;

        NN_LOG("Calling poll w/to %d for read when there's data to read\n", MsecPollTimeout);
        nn::os::Tick startTick = nn::os::GetSystemTick();
        result = pSslConnection->Poll(&pollOutEvent, &pollEvent, MsecPollTimeout);
        uint64_t elapsed = nn::os::ConvertToTimeSpan(nn::os::GetSystemTick() - startTick).GetSeconds();
        EXPECT_TRUE(result.IsSuccess());
        EXPECT_TRUE(elapsed == 0);
        EXPECT_TRUE(pollOutEvent == nn::ssl::Connection::PollEvent::PollEvent_Read);
        NN_LOG("Poll returned after %d sec\n", elapsed);
        PrintPollEventState(pollOutEvent);
    }

    int receivedTotalBytes    = 0;
    do
    {
        nn::ssl::Connection::PollEvent pollEvent = nn::ssl::Connection::PollEvent::PollEvent_None;
        nn::ssl::Connection::PollEvent pollOutEvent = nn::ssl::Connection::PollEvent::PollEvent_None;
        pollEvent |= nn::ssl::Connection::PollEvent::PollEvent_Read;

        NN_LOG(" Calling poll w/to %d\n", MsecPollTimeout);
        result = pSslConnection->Poll(&pollOutEvent, &pollEvent, MsecPollTimeout);
        EXPECT_TRUE(result.IsSuccess());

        if((pollOutEvent & nn::ssl::Connection::PollEvent::PollEvent_Read)
           == nn::ssl::Connection::PollEvent::PollEvent_Read)
        {
            char tmpBuff[1024] = {0};
            int receivedBytes = pSslConnection->Read(tmpBuff, sizeof(tmpBuff));
            EXPECT_TRUE(receivedBytes >= 0);
            if(receivedBytes < 0)
            {
                NN_LOG(" nn::ssl::Read failed!\n");
                break;
            }
            if(receivedBytes == 0)
            {
                NN_LOG(" Connection closed by the server.\n");
                break;
            }
            receivedTotalBytes += receivedBytes;
        }
    } while(NN_STATIC_CONDITION(false));
    NN_LOG(" Received %d bytes\n", receivedTotalBytes);

    // Cleanup
    result = pSslConnection->Destroy();
    EXPECT_TRUE(result.IsSuccess());
    delete pSslConnection;

    result = pSslContext->Destroy();
    EXPECT_TRUE(result.IsSuccess());
    delete pSslContext;

    if(isSocketImported != true)
    {
        g_CommonUtil.CloseTcpSocket(socketFd);
    }
} // NOLINT(impl/function_size)

TEST(PollInNonblocking, Success)
{
    bool isSocketImported = false;
    nn::Result result;
    int socketFd = g_CommonUtil.CreateTcpSocket(
        true,
        ServerPort_Normal,
#ifndef NO_RESOLVER
        ServerName,
        0
#else
        nullptr,
        ServerIpAddress
#endif
    );
    EXPECT_TRUE(socketFd >= 0);

    nn::ssl::Context*    pSslContext = new nn::ssl::Context();;
    nn::ssl::Connection* pSslConnection = new nn::ssl::Connection();;
    ASSERT_TRUE(pSslContext != nullptr);
    ASSERT_TRUE(pSslConnection != nullptr);

    result = pSslContext->Create(nn::ssl::Context::SslVersion_Auto);
    ASSERT_TRUE(result.IsSuccess());
    result = pSslConnection->Create(pSslContext);
    ASSERT_TRUE(result.IsSuccess());

    result = pSslConnection->SetSocketDescriptor(socketFd);
    ASSERT_TRUE(result.IsSuccess());
    isSocketImported = true;
    result = pSslConnection->SetHostName(ServerName, strlen(ServerName));
    ASSERT_TRUE(result.IsSuccess());

    result = pSslConnection->SetOption(nn::ssl::Connection::OptionType_SkipDefaultVerify, true);
    ASSERT_TRUE(result.IsSuccess());
    result = pSslConnection->SetVerifyOption(nn::ssl::Connection::VerifyOption::VerifyOption_None);
    ASSERT_TRUE(result.IsSuccess());

    result = pSslConnection->DoHandshake();
    ASSERT_TRUE(result.IsSuccess());
    NN_LOG(" SSL Handshake completed.\n");

    result = pSslConnection->SetIoMode(nn::ssl::Connection::IoMode_NonBlocking);
    ASSERT_TRUE(result.IsSuccess());

    // Poll should return immediately
    {
        nn::ssl::Connection::PollEvent pollEvent = nn::ssl::Connection::PollEvent::PollEvent_None;
        nn::ssl::Connection::PollEvent pollOutEvent = nn::ssl::Connection::PollEvent::PollEvent_None;
        pollEvent |= nn::ssl::Connection::PollEvent::PollEvent_Write;

        NN_LOG("Calling poll w/to %d for write.\n", MsecPollTimeout);
        nn::os::Tick startTick = nn::os::GetSystemTick();
        result = pSslConnection->Poll(&pollOutEvent, &pollEvent, MsecPollTimeout);
        uint64_t elapsed = nn::os::ConvertToTimeSpan(nn::os::GetSystemTick() - startTick).GetSeconds();
        EXPECT_TRUE(result.IsSuccess());
        EXPECT_TRUE(elapsed == 0);
        EXPECT_TRUE(pollOutEvent == nn::ssl::Connection::PollEvent::PollEvent_Write);
        NN_LOG("Poll returned after %d sec.\n", elapsed);
        PrintPollEventState(pollOutEvent);
    }

    // Poll should return after MsecPollTimeout
    {
        nn::ssl::Connection::PollEvent pollEvent = nn::ssl::Connection::PollEvent::PollEvent_None;
        nn::ssl::Connection::PollEvent pollOutEvent = nn::ssl::Connection::PollEvent::PollEvent_None;
        pollEvent |= nn::ssl::Connection::PollEvent::PollEvent_Read;

        NN_LOG("Calling poll w/to %d for read.\n", MsecPollTimeout);
        nn::os::Tick startTick = nn::os::GetSystemTick();
        result = pSslConnection->Poll(&pollOutEvent, &pollEvent, MsecPollTimeout);
        uint64_t elapsed = nn::os::ConvertToTimeSpan(nn::os::GetSystemTick() - startTick).GetMilliSeconds();
        EXPECT_TRUE(nn::ssl::ResultIoTimeout::Includes(result));
        EXPECT_TRUE(elapsed >= MsecPollTimeout - 5);
        EXPECT_TRUE(pollOutEvent == nn::ssl::Connection::PollEvent::PollEvent_None);
        NN_LOG("Poll returned after %d millisec.\n", elapsed);
        PrintPollEventState(pollOutEvent);
    }

    char httpReqBuff[64] = {0};
    MY_SNPRINTF(httpReqBuff, sizeof(httpReqBuff), "GET / HTTP/1.0\r\nHost: %s\r\n\r\n", ServerName);
    uint32_t httpReqBuffLen = static_cast<uint32_t>(strlen(httpReqBuff));
    int sentBytes = pSslConnection->Write(httpReqBuff, httpReqBuffLen);
    ASSERT_TRUE(sentBytes > 0);
    NN_LOG(" Sent HTTP request over SSL (%d bytes).\n", sentBytes);

    // Wait data from the server
    nn::os::SleepThread(nn::TimeSpan::FromMilliSeconds(MsecWaitServerReply));

    // Poll should return immediately
    {
        nn::ssl::Connection::PollEvent pollEvent = nn::ssl::Connection::PollEvent::PollEvent_None;
        nn::ssl::Connection::PollEvent pollOutEvent = nn::ssl::Connection::PollEvent::PollEvent_None;
        pollEvent |= nn::ssl::Connection::PollEvent::PollEvent_Read;

        NN_LOG("Calling poll w/to %d for read when there's data to read\n", MsecPollTimeout);
        nn::os::Tick startTick = nn::os::GetSystemTick();
        result = pSslConnection->Poll(&pollOutEvent, &pollEvent, MsecPollTimeout);
        uint64_t elapsed = nn::os::ConvertToTimeSpan(nn::os::GetSystemTick() - startTick).GetSeconds();
        EXPECT_TRUE(result.IsSuccess());
        EXPECT_TRUE(elapsed == 0);
        EXPECT_TRUE(pollOutEvent == nn::ssl::Connection::PollEvent::PollEvent_Read);
        NN_LOG("Poll returned after %d sec\n", elapsed);
        PrintPollEventState(pollOutEvent);
    }

    int receivedTotalBytes    = 0;
    do
    {
        nn::ssl::Connection::PollEvent pollEvent = nn::ssl::Connection::PollEvent::PollEvent_None;
        nn::ssl::Connection::PollEvent pollOutEvent = nn::ssl::Connection::PollEvent::PollEvent_None;
        pollEvent |= nn::ssl::Connection::PollEvent::PollEvent_Read;

        NN_LOG(" Calling poll w/to %d\n", MsecPollTimeout);
        result = pSslConnection->Poll(&pollOutEvent, &pollEvent, MsecPollTimeout);
        EXPECT_TRUE(result.IsSuccess());

        if((pollOutEvent & nn::ssl::Connection::PollEvent::PollEvent_Read)
           == nn::ssl::Connection::PollEvent::PollEvent_Read)
        {
            char tmpBuff[1024] = {0};
            int receivedBytes = pSslConnection->Read(tmpBuff, sizeof(tmpBuff));
            EXPECT_TRUE(receivedBytes >= 0);
            if(receivedBytes < 0)
            {
                NN_LOG(" nn::ssl::Read failed!\n");
                break;
            }
            if(receivedBytes == 0)
            {
                NN_LOG(" Connection closed by the server.\n");
                break;
            }
            receivedTotalBytes += receivedBytes;
        }
    } while(NN_STATIC_CONDITION(false));
    NN_LOG(" Received %d bytes\n", receivedTotalBytes);

    // Cleanup
    result = pSslConnection->Destroy();
    EXPECT_TRUE(result.IsSuccess());
    delete pSslConnection;

    result = pSslContext->Destroy();
    EXPECT_TRUE(result.IsSuccess());
    delete pSslContext;

    if(isSocketImported != true)
    {
        g_CommonUtil.CloseTcpSocket(socketFd);
    }
} // NOLINT(impl/function_size)

// Tears down, in reverse order of InitTest: the SSL library, the socket
// library, and finally the network. Only the SSL result is checked.
TEST(FinalizeTest, Success)
{
    const auto sslFinalizeResult = nn::ssl::Finalize();
    EXPECT_TRUE(sslFinalizeResult.IsSuccess());

    nn::socket::Finalize();

    g_CommonUtil.FinalizeNetwork();
}
