﻿/*--------------------------------------------------------------------------------*
  Copyright (C)Nintendo All rights reserved.

  These coded instructions, statements, and computer programs contain proprietary
  information of Nintendo and/or its licensed developers and are protected by
  national and international copyright laws. They may not be disclosed to third
  parties or copied or duplicated in any form, in whole or in part, without the
  prior written consent of Nintendo.

  The content herein is highly confidential and should be handled accordingly.
 *--------------------------------------------------------------------------------*/
#include <cstdio>
#include <cstring>
#include <nnt/nntest.h>
#include <nn/nn_Log.h>
#include <nn/init.h>
#include <nn/os.h>
#include <nn/nifm.h>
#include <nn/ssl.h>
#include <nn/socket.h>

#include <Common/testCommonUtil.h>
#include <Common/testInfraInfo.h>
#include <Common/testServerPki.h>
#include <Common/testClientPki.h>
#include <Common/testClientPkiNoPwd.h>

// ------------------------------------------------------------------------------------------------
// [Description]
// This test program verifies that the SSL library does not use the default heap.
// At startup it fills the default heap with a fixed 32-bit pattern. Just before starting
// the SSL library, it counts how many words of that pattern have already been overwritten
// (changes made before the SSL library starts do not matter). After the SSL library is
// finalized, it counts the overwritten words again. This second count must equal the one
// obtained before initializing the SSL library.
//
// This test calls a broad set of the SSL APIs to confirm that none of them uses the
// default heap.
//
// Note that this test is supposed to succeed only on SDEV.
// ------------------------------------------------------------------------------------------------

// ------------------------------------------------------------------------------------------------
// Utils
// ------------------------------------------------------------------------------------------------
namespace
{
    // Sentinel written into every 32-bit word of the heap at startup. A word that no longer
    // holds this value has been written to since initialization.
    const uint32_t           AddressInitPattern = 0xCAFEBEEF;
    // Size of the memory heap that nninitStartup() hands to the default allocator.
    const size_t             MemoryHeapSize     = 1024 * 1024 * 2;

    SslTestCommonUtil        g_CommonUtil;
    NN_ALIGNAS(4096) uint8_t g_SocketMemoryPoolBuffer[nn::socket::DefaultSocketMemoryPoolSize];
    uintptr_t                g_HeapBase;      // Start address of the heap block (set in nninitStartup)
    uint32_t                 g_BaseErrCount;  // Mismatching-word count taken before SSL starts

    // Fills the first (len / sizeof(uint32_t)) words at ptr with AddressInitPattern.
    // Trailing bytes that do not make up a whole word are left untouched.
    //
    // ptr: start of the region; assumed to be 4-byte aligned (it is the heap block base here).
    // len: size of the region in bytes.
    void InitializeMemoryChunk(char* ptr, size_t len)
    {
        uint32_t* const pWords    = reinterpret_cast<uint32_t*>(ptr);
        const size_t    wordCount = len / sizeof(uint32_t);

        for (size_t i = 0; i < wordCount; ++i)
        {
            pWords[i] = AddressInitPattern;
        }

        // Cast explicitly so the arguments match %d regardless of the width of size_t.
        NN_LOG("!!! Initialized %d bytes (sizeof(char):%d sizeof(uint32_t):%d)\n",
            static_cast<int>(wordCount * sizeof(uint32_t)),
            static_cast<int>(sizeof(char)),
            static_cast<int>(sizeof(uint32_t)));
    }

    // Returns how many whole 32-bit words in [ptr, ptr + len) no longer hold
    // AddressInitPattern. The result counts words, not bytes.
    uint32_t GetErrCount(char* ptr, size_t len)
    {
        const uint32_t* const pWords    = reinterpret_cast<const uint32_t*>(ptr);
        const size_t          wordCount = len / sizeof(uint32_t);
        uint32_t              errCount  = 0;

        for (size_t i = 0; i < wordCount; ++i)
        {
            if (pWords[i] != AddressInitPattern)
            {
                errCount++;
            }
        }
        return errCount;
    }

    // Asserts that the number of overwritten words in [ptr, ptr + len) still equals
    // baseErrCount, i.e. nothing in the region was written after the baseline was taken.
    void VerifyMemoryCnunk(char* ptr, size_t len, uint32_t baseErrCount)
    {
        const uint32_t errCount = GetErrCount(ptr, len);
        if (errCount != baseErrCount)
        {
            // Signed difference: avoids unsigned wrap-around if the count somehow dropped.
            // The counts are word counts, so report both words and the equivalent bytes.
            const int wordDelta = static_cast<int>(errCount) - static_cast<int>(baseErrCount);
            NN_LOG("Detected %d words (%d bytes) got overwritten in the default heap\n",
                wordDelta, wordDelta * static_cast<int>(sizeof(uint32_t)));
        }
        ASSERT_TRUE(errCount == baseErrCount);

        NN_LOG("!!! Verified memory chunk.\n");
    }

} // Un-named namespace

extern "C" void nninitStartup()
{
    // Size the application memory heap, then take all of it as a single memory block.
    auto setupResult = nn::os::SetMemoryHeapSize(MemoryHeapSize);
    ASSERT_TRUE( setupResult.IsSuccess() );

    g_HeapBase = 0;
    setupResult = nn::os::AllocateMemoryBlock(&g_HeapBase, MemoryHeapSize);
    ASSERT_TRUE(setupResult.IsSuccess());

    // Stamp the block with the sentinel pattern before the default allocator takes it over,
    // so that later writes through that allocator show up as mismatching words.
    void* const pHeapBlock = reinterpret_cast<void*>(g_HeapBase);
    InitializeMemoryChunk(static_cast<char*>(pHeapBlock), MemoryHeapSize);

    nn::init::InitializeAllocator( pHeapBlock, MemoryHeapSize );
}

// ------------------------------------------------------------------------------------------------
// Tests
// ------------------------------------------------------------------------------------------------
TEST(InitTest, Success)
{
    // Bring up the network, then the socket library. RunTest opens its TCP connection
    // through g_CommonUtil and relies on both having succeeded.
    const bool isNetworkReady = g_CommonUtil.SetupNetwork().IsSuccess();
    ASSERT_TRUE(isNetworkReady);

    const auto socketInitResult = nn::socket::Initialize(
        g_SocketMemoryPoolBuffer,
        nn::socket::DefaultSocketMemoryPoolSize,
        nn::socket::MinSocketAllocatorSize,
        nn::socket::DefaultConcurrencyLimit);
    ASSERT_TRUE(socketInitResult.IsSuccess());
}

TEST(RunTest, Success)
{
    // Open the TCP connection before taking the baseline: writes made by the socket layer are
    // not attributed to the SSL library.
    int socketFd = g_CommonUtil.CreateTcpSocket(
        true,
        ServerPort_ClientAuth,
#ifndef NO_RESOLVER
        ServerName,
        0
#else
        nullptr,
        ServerIpAddress
#endif
    );

    // Get the baseline error count here. Anything already written at this point is not caused
    // by the SSL library. The count must be unchanged when VerifyMemoryCnunk() re-checks it at
    // the bottom of this function.
    g_BaseErrCount = GetErrCount(reinterpret_cast<char *>(g_HeapBase), MemoryHeapSize);

    // --------------------------------------------------------------------------------------------
    // Start nn::ssl
    // --------------------------------------------------------------------------------------------
    nn::Result               result;
    nn::ssl::Context         sslContext;
    nn::ssl::Connection      sslConnection;
    nn::ssl::Context*        pSslContext = &sslContext;
    nn::ssl::Connection*     pSslConnection = &sslConnection;
    nn::ssl::SslConnectionId connectionId;
    nn::ssl::SslContextId    contextId;
    nn::ssl::CertStoreId     certStoreId;
    nn::ssl::CertStoreId     clientCertStoreId;

    result = nn::ssl::Initialize();
    EXPECT_TRUE(result.IsSuccess());

    // --------------------------------------------------------------------------------------------
    // nn::ssl::Context
    // --------------------------------------------------------------------------------------------
    result = pSslContext->Create(nn::ssl::Context::SslVersion_Auto);
    EXPECT_TRUE(result.IsSuccess());

    result = pSslContext->GetContextId(&contextId);
    EXPECT_TRUE(result.IsSuccess());

    result = pSslContext->ImportServerPki(
            &certStoreId,
            g_pTestCaCert,
            sizeof(g_pTestCaCert),
            nn::ssl::CertificateFormat_Pem);
    EXPECT_TRUE(result.IsSuccess());

    result = pSslContext->ImportClientPki(
        &clientCertStoreId,
        (const char *)g_pTestClientPki,
        (const char *)g_pTestClientPkiPassword,
        g_pTestClientPkiSize,
        g_TestClientPkiPasswordLength);
    EXPECT_TRUE(result.IsSuccess());

    // --------------------------------------------------------------------------------------------
    // nn::ssl::Connection
    // --------------------------------------------------------------------------------------------
    result = pSslConnection->Create(pSslContext);
    EXPECT_TRUE(result.IsSuccess());

    // ----------------------------------------------------------------------------------------
    // Set APIs
    // ----------------------------------------------------------------------------------------
    result = pSslConnection->SetSocketDescriptor(socketFd);
    EXPECT_TRUE(result.IsSuccess());

    result = pSslConnection->SetIoMode(nn::ssl::Connection::IoMode_Blocking);
    EXPECT_TRUE(result.IsSuccess());

    result = pSslConnection->SetSessionCacheMode(nn::ssl::Connection::SessionCacheMode_SessionId);
    EXPECT_TRUE(result.IsSuccess());

    result = pSslConnection->SetRenegotiationMode(nn::ssl::Connection::RenegotiationMode_Secure);
    EXPECT_TRUE(result.IsSuccess());

    result = pSslConnection->SetVerifyOption(nn::ssl::Connection::VerifyOption::VerifyOption_All);
    EXPECT_TRUE(result.IsSuccess());

    result = pSslConnection->SetHostName(ServerName, strlen(ServerName));
    EXPECT_TRUE(result.IsSuccess());

    // Caller-owned buffer for the server certificate; it lives on the stack for the whole
    // connection so that no heap memory is needed for it.
    char certBuffer[1024 * 16];
    result = pSslConnection->SetServerCertBuffer(certBuffer, sizeof(certBuffer));
    EXPECT_TRUE(result.IsSuccess());

    // ----------------------------------------------------------------------------------------
    // Get APIs
    // ----------------------------------------------------------------------------------------
    result = pSslConnection->GetContextId(&contextId);
    EXPECT_TRUE(result.IsSuccess());

    result = pSslConnection->GetConnectionId(&connectionId);
    EXPECT_TRUE(result.IsSuccess());

    int tmpSock;
    result = pSslConnection->GetSocketDescriptor(&tmpSock);
    EXPECT_TRUE(result.IsSuccess());

    nn::ssl::Connection::IoMode tmpIoMode;
    result = pSslConnection->GetIoMode(&tmpIoMode);
    EXPECT_TRUE(result.IsSuccess());

    nn::ssl::Connection::SessionCacheMode tmpSessionCacheMode;
    result = pSslConnection->GetSessionCacheMode(&tmpSessionCacheMode);
    EXPECT_TRUE(result.IsSuccess());

    nn::ssl::Connection::RenegotiationMode tmpRenegoMode;
    result = pSslConnection->GetRenegotiationMode(&tmpRenegoMode);
    EXPECT_TRUE(result.IsSuccess());

    nn::ssl::Connection::VerifyOption tmpVerifyOption;
    result = pSslConnection->GetVerifyOption(&tmpVerifyOption);
    EXPECT_TRUE(result.IsSuccess());

    char tmpNameString[256];
    uint32_t tmpNameLen;
    result = pSslConnection->GetHostName(tmpNameString, &tmpNameLen, sizeof(tmpNameString));
    EXPECT_TRUE(result.IsSuccess());

    // ----------------------------------------------------------------------------------------
    // DoHandshake
    // ----------------------------------------------------------------------------------------
    // Use the overload that also reports the server certificate size/count, so that those
    // output paths are exercised too.
    uint32_t tmpCertSize;
    uint32_t tmpCertCount;
    result = pSslConnection->DoHandshake(&tmpCertSize, &tmpCertCount);
    EXPECT_TRUE(result.IsSuccess());

    result = pSslConnection->FlushSessionCache();
    EXPECT_TRUE(result.IsSuccess());

    uint32_t tmpNeededSrvCertBuffSize;
    result = pSslConnection->GetNeededServerCertBufferSize(&tmpNeededSrvCertBuffSize);
    EXPECT_TRUE(result.IsSuccess());

    nn::Result tmpVerifyResult;
    result = pSslConnection->GetVerifyCertError(&tmpVerifyResult);
    EXPECT_TRUE(result.IsSuccess());

    // ----------------------------------------------------------------------------------------
    // Read/Write
    // ----------------------------------------------------------------------------------------
    // Bounded formatting: ServerName comes from the test infrastructure, and an overly long
    // value must not overflow the stack buffer. Flag truncation as a test setup error.
    char httpReqBuff[128] = {0};
    const int reqLen = snprintf(httpReqBuff, sizeof(httpReqBuff),
        "GET / HTTP/1.0\r\nHost: %s\r\n\r\n", ServerName);
    EXPECT_TRUE((reqLen > 0) && (static_cast<size_t>(reqLen) < sizeof(httpReqBuff)));

    int sentBytes;
    result = pSslConnection->Write(httpReqBuff, &sentBytes, static_cast<uint32_t>(strlen(httpReqBuff)));
    EXPECT_TRUE(result.IsSuccess());

    nn::ssl::Connection::PollEvent pollEvent = nn::ssl::Connection::PollEvent::PollEvent_None;
    nn::ssl::Connection::PollEvent pollOutEvent = nn::ssl::Connection::PollEvent::PollEvent_None;
    pollEvent |= nn::ssl::Connection::PollEvent::PollEvent_Read;
    result = pSslConnection->Poll(&pollOutEvent, &pollEvent, 5000);
    EXPECT_TRUE(result.IsSuccess());

    int tmpPending;
    result = pSslConnection->Pending(&tmpPending);
    EXPECT_TRUE(result.IsSuccess());

    char readBuff[1024];
    int  tmpReadSize;
    result = pSslConnection->Peek(readBuff, &tmpReadSize, sizeof(readBuff));
    EXPECT_TRUE(result.IsSuccess());

    // Drain the response until the peer closes the connection (0 bytes read) or Read() fails.
    // Stop on an error as well: retrying a failed blocking Read() would spin forever on a
    // persistent error such as a reset connection. Reset the size each pass so a failed call
    // cannot leave a stale non-zero value behind.
    do
    {
        tmpReadSize = 0;
        result = pSslConnection->Read(readBuff, &tmpReadSize, sizeof(readBuff));
    } while (result.IsSuccess() && (tmpReadSize > 0));

    // --------------------------------------------------------------------------------------------
    // Cleanup
    // --------------------------------------------------------------------------------------------
    result = pSslConnection->Destroy();
    EXPECT_TRUE(result.IsSuccess());

    result = pSslContext->RemovePki(certStoreId);
    EXPECT_TRUE(result.IsSuccess());

    result = pSslContext->RemovePki(clientCertStoreId);
    EXPECT_TRUE(result.IsSuccess());

    result = pSslContext->Destroy();
    EXPECT_TRUE(result.IsSuccess());

    // --------------------------------------------------------------------------------------------
    // End nn::ssl
    // --------------------------------------------------------------------------------------------
    result = nn::ssl::Finalize();
    EXPECT_TRUE(result.IsSuccess());

    // Verify how much of the default heap is overwritten. If the count matches the baseline
    // (g_BaseErrCount) taken at the top of this function, the SSL library wrote nothing there.
    VerifyMemoryCnunk(reinterpret_cast<char *>(g_HeapBase), MemoryHeapSize, g_BaseErrCount);
} // NOLINT(impl/function_size)

TEST(FinalizeTest, Success)
{
    // Tear down in the reverse order of InitTest: the socket library first, then the network.
    // Check the socket result (as InitTest does for Initialize) instead of discarding it.
    // Use EXPECT rather than ASSERT so that the network is still finalized on failure.
    const auto socketFinalizeResult = nn::socket::Finalize();
    EXPECT_TRUE(socketFinalizeResult.IsSuccess());

    g_CommonUtil.FinalizeNetwork();
}
