﻿/*--------------------------------------------------------------------------------*
  Copyright (C)Nintendo All rights reserved.

  These coded instructions, statements, and computer programs contain proprietary
  information of Nintendo and/or its licensed developers and are protected by
  national and international copyright laws. They may not be disclosed to third
  parties or copied or duplicated in any form, in whole or in part, without the
  prior written consent of Nintendo.

  The content herein is highly confidential and should be handled accordingly.
 *--------------------------------------------------------------------------------*/
#include <cstdio>
#include <cstring>

#include <nn/nn_Log.h>
#include <nn/nn_Assert.h>
#include <nn/init.h>
#include <nn/os.h>

#include <nn/ssl.h>
#include <nn/ssl/ssl_Api.debug.h>
#include <nn/socket.h>

#include "ExecuterTestsBase.h"

// ------------------------------------------------------------------------------------------------
// ExecuterTestsBase
// ------------------------------------------------------------------------------------------------
// Builds the test base with the result flag cleared (not succeeded) and verbose logging
// switched off.
ExecuterTestsBase::ExecuterTestsBase()
    : m_IsSucceeded(false)
    , m_IsVerbose(false)
{
}

bool ExecuterTestsBase::IsSucceeded()
{
    return m_IsSucceeded;
}

void ExecuterTestsBase::SetResult(bool isSuccessed)
{
    m_IsSucceeded = isSuccessed;
}

void ExecuterTestsBase::EnableVerbose()
{
    m_IsVerbose = true;
}

bool ExecuterTestsBase::IsVerbose()
{
    return m_IsVerbose;
}

// Logs the stored test result ("Success" or "Fail") framed by two separator lines.
void ExecuterTestsBase::PrintResult()
{
    NN_LOG(" ****************\n");
    NN_LOG(" Result: %s\n", m_IsSucceeded ? "Success" : "Fail");
    NN_LOG(" ****************\n");
}

int ExecuterTestsBase::CreateTcpSocket(
    bool bEstablishConn,
    uint16_t portNumber,
    const char* pInHostName,
    uint32_t ipAddress)
{
    int                tcpSocket;
    nn::socket::InAddr     inetAddr;
    nn::socket::SockAddrIn serverAddr;

    memset(&inetAddr, 0x00, sizeof(inetAddr));

    if(pInHostName == nullptr && ipAddress == 0)
    {
        NN_LOG(" Host name and IP address was not passed.\n");
        return -1;
    }

    tcpSocket = nn::socket::Socket(nn::socket::Family::Af_Inet, nn::socket::Type::Sock_Stream, nn::socket::Protocol::IpProto_Tcp);
    if(tcpSocket < 0)
    {
        NN_LOG(" Failed to create TCP socket (errno: %d)\n", errno);
        return -1;
    }
    TEST_CASE_VERBOSE_LOG(" Created TCP socket (sockfd: %d).\n", tcpSocket);

    if (bEstablishConn != true)
    {
        return tcpSocket;
    }

    if(ipAddress == 0)
    {
        TEST_CASE_VERBOSE_LOG(" Resolving %s\n", pInHostName);
        nn::socket::HostEnt *pHostEnt = nn::socket::GetHostEntByName(pInHostName);
        if(pHostEnt == nullptr)
        {
            NN_LOG(" Failed to resolve host name (errno:%d)\n", errno);
            nn::socket::Close(tcpSocket);
            return -1;
        }

        // Just pick the first one
        memcpy(&inetAddr, pHostEnt->h_addr_list[0], sizeof(nn::socket::InAddr));
        serverAddr.sin_addr.S_addr = inetAddr.S_addr;
    }
    else
    {
        TEST_CASE_VERBOSE_LOG(" Use 0x%x for server IP.\n", ipAddress);
        serverAddr.sin_addr.S_addr = ipAddress;
    }

    serverAddr.sin_family = nn::socket::Family::Af_Inet;
    serverAddr.sin_port   = nn::socket::InetHtons(portNumber);

    int rval = nn::socket::Connect(tcpSocket, (nn::socket::SockAddr*)&serverAddr, sizeof(serverAddr));
    if(rval < 0)
    {
        NN_LOG(" Failed to establish TCP connection (errno: %d).\n", errno);
        nn::socket::Close(tcpSocket);
        return -1;
    }

    TEST_CASE_VERBOSE_LOG(" Established TCP connection (addr:0x%x on port :%d).\n",
        serverAddr.sin_addr.S_addr, nn::socket::InetNtohs(serverAddr.sin_port));

    return tcpSocket;
}

// Performs an SSL handshake with pInHostName:portNumber using the supplied context and
// connection objects.
//
// Sequence: optionally create the context and connection, open and connect a TCP socket via
// CreateTcpSocket(), configure the connection (socket descriptor, IO mode, host name,
// OptionType_SkipDefaultVerify, verify option, session cache mode), then call DoHandshake().
//
// Parameters:
//   pInHostName                - Host that is resolved, connected to and set as the SSL host
//                                name.
//   pSslContext                - SSL context; created here when contextsReady is false.
//   pSslConnection             - SSL connection; created here when contextsReady is false.
//   portNumber                 - Server TCP port.
//   verifyOption               - Passed to SetVerifyOption().
//   sessionCacheMode           - Passed to SetSessionCacheMode().
//   ioMode                     - Passed to SetIoMode().
//   contextsReady              - true when the caller already created context and connection.
//   keepConnection             - When false, the connection and the context are destroyed
//                                before returning (each only if its id is non-zero).
//   returnUponHandshakeFailure - When true, stop after the first handshake failure other than
//                                ResultIoWouldBlock. When false, the handshake is attempted
//                                again after such a failure until it succeeds.
//   getServerCert              - When true, a 32 KB buffer is registered through
//                                SetServerCertBuffer() for the duration of the handshake.
//
// Returns the nn::Result that was stored last. Caveats kept from the existing behavior:
//   - When keepConnection is false, the results of the Destroy() calls overwrite the result
//     of the handshake / setup steps.
//   - No failure result is assigned when CreateTcpSocket() fails or when the server cert
//     buffer cannot be allocated; the value stored before that point is returned.
nn::Result ExecuterTestsBase::PerformHandshake(
    const char* pInHostName,
    nn::ssl::Context    *pSslContext,
    nn::ssl::Connection *pSslConnection,
    uint16_t portNumber,
    nn::ssl::Connection::VerifyOption verifyOption,
    nn::ssl::Connection::SessionCacheMode sessionCacheMode,
    nn::ssl::Connection::IoMode ioMode,
    bool contextsReady,
    bool keepConnection,
    bool returnUponHandshakeFailure,
    bool getServerCert)
{
    nn::Result result;

    // Single-pass block: "break" is used to skip the remaining steps once one of them fails.
    do
    {
        if (contextsReady == false)
        {
            result = pSslContext->Create(nn::ssl::Context::SslVersion_Auto);
            if (result.IsFailure())
            {
                TEST_CASE_VERBOSE_LOG(" Failed to create a SSL context.\n");
                break;
            }

            result = pSslConnection->Create(pSslContext);
            if (result.IsFailure())
            {
                TEST_CASE_VERBOSE_LOG(" Failed to create a SSL connection.\n");
                break;
            }
        }

        int socketFd = ExecuterTestsBase::CreateTcpSocket(true, portNumber, pInHostName, 0);
        if (socketFd < 0)
        {
            TEST_CASE_VERBOSE_LOG(" Failed to create a TCP socket.\n");
            break;
        }

        result = pSslConnection->SetSocketDescriptor(socketFd);
        if (result.IsFailure())
        {
            TEST_CASE_VERBOSE_LOG(" Failed to set a socket descriptor.\n");
            // Nothing else in this function closes the descriptor, so release it here.
            // NOTE(review): assumes the connection does not take over the descriptor when
            // SetSocketDescriptor() fails - confirm against the nn::ssl documentation.
            nn::socket::Close(socketFd);
            break;
        }

        result = pSslConnection->SetIoMode(ioMode);
        if (result.IsFailure())
        {
            TEST_CASE_VERBOSE_LOG(" Failed to set IO mode.\n");
            break;
        }

        uint32_t hostNameLen = static_cast<uint32_t>(strlen(pInHostName));
        result = pSslConnection->SetHostName(pInHostName, hostNameLen);
        if (result.IsFailure())
        {
            TEST_CASE_VERBOSE_LOG(" Failed to set a host name.\n");
            break;
        }

        result = pSslConnection->SetOption(nn::ssl::Connection::OptionType_SkipDefaultVerify, true);
        if (result.IsFailure())
        {
            TEST_CASE_VERBOSE_LOG(" Failed to set a option (OptionType_SkipDefaultVerify).\n");
            break;
        }

        result = pSslConnection->SetVerifyOption(verifyOption);
        if (result.IsFailure())
        {
            TEST_CASE_VERBOSE_LOG(" Failed to set a verify option.\n");
            break;
        }

        result = pSslConnection->SetSessionCacheMode(sessionCacheMode);
        if (result.IsFailure())
        {
            TEST_CASE_VERBOSE_LOG(" Failed to set a session cache mode.\n");
            break;
        }

        // The server cert buffer is allocated and registered once, before the retry loop.
        // Doing this inside the loop would allocate a fresh buffer on every retry while only
        // the last one gets freed below.
        const uint32_t serverCertBufferSize = 1024 * 32;
        char* pServerCertBuffer = nullptr;
        bool  isHandshakeReady  = true;

        if (getServerCert)
        {
            pServerCertBuffer = new char[serverCertBufferSize];
            if (pServerCertBuffer == nullptr)
            {
                NN_LOG(" Failed to allocate memory for server cert buffer.\n");
                isHandshakeReady = false;
            }
            else
            {
                result = pSslConnection->SetServerCertBuffer(pServerCertBuffer, serverCertBufferSize);
                if (result.IsFailure())
                {
                    NN_LOG(" Failed to set server cert buffer.\n");
                    isHandshakeReady = false;
                }
            }
        }

        // Handshake loop: repeats while DoHandshake() reports ResultIoWouldBlock, and also
        // after other failures unless returnUponHandshakeFailure is set.
        while (isHandshakeReady)
        {
            if (getServerCert)
            {
                uint32_t serverCertSize = 0;
                result = pSslConnection->DoHandshake(&serverCertSize, nullptr);
                if (result.IsSuccess())
                {
                    TEST_CASE_VERBOSE_LOG(" Read server cert (size:%d bytes)\n", serverCertSize);
                }
            }
            else
            {
                result = pSslConnection->DoHandshake();
            }

            if (result.IsSuccess())
            {
                TEST_CASE_VERBOSE_LOG(" Handshake done.\n");
                break;
            }

            if (nn::ssl::ResultIoWouldBlock::Includes(result))
            {
                // The handshake has not completed yet; call DoHandshake() again.
                continue;
            }

            nn::Result verifyError;
            pSslConnection->GetVerifyCertError(&verifyError);
            NN_LOG(" [WARN] DoHandshake failed (%d)(VerifyError:%d)\n",
                result.GetDescription(),
                verifyError.GetDescription());
            if (returnUponHandshakeFailure == true)
            {
                break;
            }
        }

        // delete[] on a null pointer is a no-op, so no guard is needed.
        delete[] pServerCertBuffer;
    } while (NN_STATIC_CONDITION(false));

    if (keepConnection == false)
    {
        // A zero id is treated as "not created", in which case Destroy() is skipped.
        nn::ssl::SslConnectionId connId;
        pSslConnection->GetConnectionId(&connId);
        if (connId != 0)
        {
            result = pSslConnection->Destroy();
            if (result.IsFailure())
            {
                TEST_CASE_VERBOSE_LOG(" Failed to destroy a SSL connection.\n");
            }
        }

        nn::ssl::SslContextId ctxId;
        pSslContext->GetContextId(&ctxId);
        if (ctxId != 0)
        {
            result = pSslContext->Destroy();
            if (result.IsFailure())
            {
                TEST_CASE_VERBOSE_LOG(" Failed to destroy a SSL context.\n");
            }
        }
    }

    return result;
} // NOLINT(impl/function_size)

// Issues the IoctlCommand_StartHeapTrack debug ioctl, passing the caller's HeapTrackPoint
// as the input buffer. A failure is only logged; nothing is returned to the caller.
void ExecuterTestsBase::StartHeapTrack(nn::ssl::Debug::HeapTrackPoint* pInTracker)
{
    nn::ssl::Debug::Input trackerInput;
    trackerInput.pBuffer    = reinterpret_cast<const char*>(pInTracker);
    trackerInput.bufferSize = sizeof(nn::ssl::Debug::HeapTrackPoint);

    // This command takes no output buffer.
    const bool isFailed = nn::ssl::Debug::Ioctl(
        nullptr,
        &trackerInput,
        nn::ssl::Debug::IoctlCommand_StartHeapTrack).IsFailure();
    if (isFailed)
    {
        NN_LOG(" [StartHeapTrack] Failed to set start heap track point.\n");
    }
}

// Issues the IoctlCommand_EndHeapTrack debug ioctl for the given HeapTrackPoint and logs the
// returned statistics: elapsed milliseconds, current heap space and delta, the latter two
// also divided by 1024 for a KB figure. A failure is only logged.
void ExecuterTestsBase::EndHeapTrack(nn::ssl::Debug::HeapTrackPoint* pInTracker)
{
    nn::ssl::Debug::TrackStats trackStats;
    nn::ssl::Debug::Input      trackerInput;
    nn::ssl::Debug::Output     statsOutput;

    // Input: the track point supplied by the caller.
    trackerInput.pBuffer    = reinterpret_cast<const char*>(pInTracker);
    trackerInput.bufferSize = sizeof(nn::ssl::Debug::HeapTrackPoint);

    // Output: the statistics structure filled in by the ioctl.
    statsOutput.pBuffer     = reinterpret_cast<char*>(&trackStats);
    statsOutput.bufferSize  = sizeof(nn::ssl::Debug::TrackStats);

    if (nn::ssl::Debug::Ioctl(
            &statsOutput,
            &trackerInput,
            nn::ssl::Debug::IoctlCommand_EndHeapTrack).IsFailure())
    {
        NN_LOG(" [EndHeapTrack] Failed to find end heap track point.\n");
        return;
    }

    NN_LOG(" [EndHeapTrack] elapsed:%dmsec current heap:%d(%dKB) delta:%d(%dKB)\n",
        trackStats.elapsedMsec,
        trackStats.curHeapSpace,
        trackStats.curHeapSpace / 1024,
        trackStats.delta,
        trackStats.delta / 1024);
}
